File size: 9,271 Bytes
22226e6
81ec1b1
 
 
 
22226e6
81ec1b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22226e6
 
81ec1b1
 
 
22226e6
 
 
 
 
 
 
 
 
 
 
 
81ec1b1
 
 
 
 
 
 
 
22226e6
81ec1b1
 
22226e6
81ec1b1
 
 
86d3119
22226e6
 
81ec1b1
 
 
 
 
22226e6
 
 
81ec1b1
 
 
 
 
008983b
 
81ec1b1
 
22226e6
 
81ec1b1
 
 
 
 
 
 
 
22226e6
008983b
22226e6
81ec1b1
22226e6
 
 
 
 
 
 
 
81ec1b1
 
 
 
 
22226e6
 
81ec1b1
 
 
86d3119
81ec1b1
 
 
 
 
 
 
 
22226e6
 
 
 
 
 
81ec1b1
 
 
22226e6
81ec1b1
 
 
22226e6
81ec1b1
 
22226e6
 
 
 
 
81ec1b1
22226e6
81ec1b1
22226e6
 
81ec1b1
 
 
 
 
 
22226e6
81ec1b1
 
 
 
22226e6
 
81ec1b1
 
 
 
 
22226e6
81ec1b1
22226e6
 
81ec1b1
 
22226e6
 
 
 
 
 
81ec1b1
22226e6
 
 
81ec1b1
 
22226e6
81ec1b1
008983b
81ec1b1
22226e6
008983b
22226e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
008983b
22226e6
 
 
 
 
 
 
 
 
 
 
 
81ec1b1
008983b
 
 
 
22226e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import os
import time
import re
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
from streamlit_chat import message
from pathlib import Path
import torch
from PyPDF2 import PdfReader
import requests
from bs4 import BeautifulSoup

# Page title/icon. set_page_config has to be the first Streamlit call of the run.
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# Stylesheet injected once per run: green buttons, tinted text area, and
# left-bordered bubbles for user/assistant chat messages and <pre> blocks.
_CUSTOM_CSS = """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 12px;
        padding: 10px 24px;
    }
    .stTextArea textarea {
        background-color: #f5f5f5;
    }
    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
    }
    .stMessageContainer {
        border-radius: 15px;
        padding: 10px;
        margin: 10px 0;
    }
    .stMessage--user {
        background-color: #dfe7f3;
        border-left: 6px solid #006699;
    }
    .stMessage--assistant {
        background-color: #f3f3f3;
        border-left: 6px solid #4CAF50;
    }
    pre {
        background-color: #f5f5f5;
        border-left: 6px solid #dfe7f3;
        padding: 10px;
        font-size: 14px;
        border-radius: 8px;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

# Seed session state only for keys that are missing, so values survive reruns.
for _state_key, _initial_value in (
    ("generated", []),
    ("past", []),
    ("messages", [{"role": "system", "content": "You are a helpful assistant."}]),
    ("chat_data", []),  # chat logs kept for later fine-tuning / reference
    ("uploaded_docs", []),  # text of uploaded documents
    ("web_data", []),  # text scraped from web pages
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value

# Sidebar: model name and the sampling parameters read by generate_response.
st.sidebar.title("Model Selection")
model_name = "gpt2"

st.sidebar.title("Response Style Controls")
# Higher temperature -> more random token choices.
temperature = st.sidebar.slider(
    "Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1
)
# Sample only from the smallest token set whose probability mass reaches top_p.
top_p = st.sidebar.slider(
    "Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05
)
# Sample only among the top_k most likely tokens.
top_k = st.sidebar.slider(
    "Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1
)
# Values above 1.0 discourage repeating tokens already in the sequence.
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1
)
# Passed to model.generate as max_length.
max_length = st.sidebar.slider(
    "Max Length", min_value=100, max_value=1024, value=800, step=10
)

@st.cache_resource
def load_model_and_tokenizer(model_path="gpt2"):
    """
    Load a causal language model and its tokenizer from one location.

    Wrapped in ``st.cache_resource`` so the weights are not reloaded on every
    Streamlit rerun.

    Args:
        model_path: Hugging Face Hub model id or path to a saved model
            directory (for example a fine-tuned checkpoint). Defaults to "gpt2".

    Returns:
        A ``(tokenizer, model)`` tuple, both loaded from ``model_path``.
    """
    # Load the tokenizer from the same path as the model so a local or
    # fine-tuned model is never paired with a different tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer(model_name)

def generate_response(prompt):
    """
    Generate a reply to ``prompt`` with the GPT-2 model, conditioned on the
    uploaded-document and web-scraped text held in session state.

    The sidebar ``max_length`` bounds the total sequence (context + reply), as
    ``model.generate(max_length=...)`` does. It is also capped at the model's
    context window. If the context alone would fill that budget, the oldest
    tokens are dropped (the prompt sits at the end, so it is kept) so that
    there is room left for a reply.

    A temperature of 0 selects deterministic greedy decoding; any positive
    temperature samples with the sidebar's temperature / top-p / top-k.

    Returns:
        Only the newly generated text, without the echoed context or prompt.
    """
    context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
    inputs = tokenizer(context, return_tensors="pt")
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask

    # GPT-2 has a fixed number of position embeddings; a longer input fails.
    context_window = getattr(model.config, "max_position_embeddings", None) or max_length
    total_budget = min(max_length, context_window)
    if input_ids.shape[1] >= total_budget:
        # Keep only the most recent tokens, leaving half the budget for the reply.
        keep = max(1, total_budget // 2)
        input_ids = input_ids[:, -keep:]
        attention_mask = attention_mask[:, -keep:]

    # The model may have been moved to a GPU (e.g. by Trainer during fine-tuning).
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    generation_config = {
        "max_length": total_budget,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generation_config.update(
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
    else:
        # Sampling with temperature 0 is rejected by transformers; the slider
        # allows 0.0, so treat it as greedy decoding instead.
        generation_config["do_sample"] = False

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        **generation_config
    )

    # generate() returns the input tokens followed by the continuation;
    # decode only the continuation so the reply does not repeat the context.
    new_tokens = outputs[0, input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def reset_session():
    """Put every chat-related session-state entry back to its starting value."""
    fresh_values = {
        "generated": [],
        "past": [],
        "messages": [{"role": "system", "content": "You are a helpful assistant."}],
        "chat_data": [],  # chat logs
        "uploaded_docs": [],  # uploaded document text
        "web_data": [],  # scraped web text
    }
    for state_key, value in fresh_values.items():
        st.session_state[state_key] = value

# Sidebar button; clicking it clears the conversation and collected context.
reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()

def save_chat_data(chat_data):
    """
    Write the chat log to ``chat_data.json`` in the current working directory.

    The file is overwritten on each call with ``chat_data`` serialised as
    JSON, indented by 4 spaces.
    """
    log_file = Path("chat_data.json")
    with log_file.open("w") as handle:
        json.dump(chat_data, handle, indent=4)

def handle_uploaded_file(uploaded_file):
    """
    Save an uploaded document as a text file under ``./datasets``.

    PDFs are converted to plain text with PyPDF2 and written as UTF-8; any
    other file is written byte-for-byte. The saved file is named
    ``<original file name>.txt``.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.

    Returns:
        The saved file's path as a string, or None if no text could be
        extracted from a PDF.
    """
    dataset_dir = Path("./datasets")
    # Nothing else creates this directory; without it the first upload fails.
    dataset_dir.mkdir(parents=True, exist_ok=True)
    # Keep only the last path component of the client-supplied name so the
    # file cannot be written outside dataset_dir.
    dataset_path = dataset_dir / f"{Path(uploaded_file.name).name}.txt"

    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        pdf_reader = PdfReader(uploaded_file)
        # extract_text() can yield None or "" for pages without a text layer.
        text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)

        if not text.strip():
            st.error("Failed to extract text from the PDF.")
            return None  # Return None if text extraction fails

        # Write UTF-8 explicitly: the platform default (e.g. cp1252 on Windows)
        # cannot encode all PDF text, and fine-tuning reads the file as UTF-8.
        dataset_path.write_text(text, encoding="utf-8")
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        dataset_path.write_bytes(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")

    return str(dataset_path)  # Return the path to the saved file

def handle_web_link(url, timeout=10):
    """
    Fetch a web page and add its visible text to ``st.session_state["web_data"]``.

    Streamlit reruns the whole script on every interaction while the URL
    stays in the sidebar box, so this function is called repeatedly for the
    same link. The page is downloaded and appended only once. It is fetched
    again only if its text is no longer in ``web_data`` (for example after
    "Reset Chat" cleared it).

    Args:
        url: Address of the page to scrape.
        timeout: Seconds to wait for the server before giving up, so a slow
            or unresponsive host cannot block the app indefinitely.

    Network/HTTP failures are reported with ``st.error`` rather than raised.
    """
    if "web_sources" not in st.session_state:
        st.session_state["web_sources"] = {}  # url -> text already added to web_data
    loaded_text = st.session_state["web_sources"].get(url)
    if loaded_text is not None and loaded_text in st.session_state["web_data"]:
        return

    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to retrieve content: {e}")
        return

    soup = BeautifulSoup(response.content, "html.parser")
    # <script>/<style> contents are code, not page text; keep them out of the
    # model context.
    for tag in soup(["script", "style"]):
        tag.decompose()
    # Separate text from adjacent elements with spaces so words do not fuse.
    text = soup.get_text(separator=" ", strip=True)
    st.session_state["web_data"].append(text)
    st.session_state["web_sources"][url] = text
    st.success(f"Content from {url} saved successfully!")

st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
if web_link:
    handle_web_link(web_link)

# Chat interface: the history container is created first so it renders above
# the input form.
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        start_time = time.time()
        output = generate_response(user_input)
        inference_time = time.time() - start_time

        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)

        # Log chat data for future training
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )

        save_chat_data(st.session_state["chat_data"])

# Render the conversation on every rerun, not only right after "Send".
# Otherwise any other interaction (moving a slider, uploading a file) reruns
# the script and the history disappears from the page.
with response_container:
    conversation = zip(st.session_state["past"], st.session_state["generated"])
    for i, (user_text, bot_text) in enumerate(conversation):
        message(user_text, is_user=True, key=f"{i}_user")
        message(bot_text, key=str(i))

def fine_tune_model():
    """
    Fine-tune the shared GPT-2 ``model`` on the most recently uploaded dataset.

    The dataset path comes from ``st.session_state["uploaded_file_path"]``,
    which is set after a successful upload. Training uses a causal-LM
    objective (``mlm=False``) over 128-token blocks with Hugging Face
    ``Trainer``. It updates ``model`` in place, so the chat uses the new
    weights immediately, and saves the final model and tokenizer to
    ``./gpt2-finetuned``. Problems are shown in the UI rather than raised.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", "")
    if not uploaded_file_path:
        st.warning("Please upload a text or PDF dataset to fine-tune the model.")
        return

    block_size = 128
    output_dir = "./gpt2-finetuned"

    try:
        # Read as UTF-8, the encoding TextDataset uses, so this check does not
        # depend on the platform's default encoding.
        with open(uploaded_file_path, "r", encoding="utf-8") as f:
            if not f.read().strip():
                raise ValueError("The dataset is empty.")

        train_dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=uploaded_file_path,  # plain-text file saved by the uploader
            block_size=block_size,
        )
        # TextDataset keeps only complete blocks, so a short text yields no
        # examples and Trainer would fail with an unhelpful error.
        if len(train_dataset) == 0:
            raise ValueError(
                f"The dataset is too short: at least {block_size} tokens are needed."
            )
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # Define training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            save_steps=10_000,
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=200,
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )

        # Fine-tune the model, then persist it. Periodic checkpoints
        # (save_steps) rarely trigger on small uploads, so without this the
        # result would be lost when the app restarts.
        trainer.train()
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        st.success("Model fine-tuning completed successfully.")

    except Exception as e:
        st.error(f"Error during fine-tuning: {str(e)}")
    finally:
        # Trainer leaves the model in training mode (dropout active). The same
        # object is used for chat generation, so switch it back to eval mode.
        model.eval()

# Sidebar: document upload. The saved path is kept in session state so a
# later click on the fine-tune button can find the dataset.
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
file_path = handle_uploaded_file(uploaded_file) if uploaded_file is not None else None
if file_path:
    st.session_state["uploaded_file_path"] = file_path

# Sidebar: button that starts fine-tuning on the uploaded dataset.
st.sidebar.title("Fine-Tune Model")
fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
if fine_tune_button:
    fine_tune_model()