odaly committed on
Commit
22226e6
·
verified ·
1 Parent(s): 86d3119

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -143
app.py CHANGED
@@ -1,11 +1,10 @@
1
- import os
2
  import time
3
  import re
4
  import json
5
  import streamlit as st
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments,TextDataset, DataCollatorForLanguageModeling
7
  from streamlit_chat import message
8
-
9
  from pathlib import Path
10
  import torch
11
  from PyPDF2 import PdfReader
@@ -53,30 +52,23 @@ st.markdown(
53
  border-radius: 8px;
54
  }
55
  </style>
56
- """, unsafe_allow_html=True
 
57
  )
58
 
59
  # Initialize session state variables
60
- if 'generated' not in st.session_state:
61
- st.session_state['generated'] = []
62
- if 'past' not in st.session_state:
63
- st.session_state['past'] = []
64
- if 'messages' not in st.session_state:
65
- st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
66
- if 'model_name' not in st.session_state:
67
- st.session_state['model_name'] = []
68
- if 'total_tokens' not in st.session_state:
69
- st.session_state['total_tokens'] = []
70
- if 'total_cost' not in st.session_state:
71
- st.session_state['total_cost'] = 0.0
72
- if 'chat_data' not in st.session_state:
73
- st.session_state['chat_data'] = [] # For storing the chat logs
74
- if 'uploaded_docs' not in st.session_state:
75
- st.session_state['uploaded_docs'] = [] # For storing uploaded document content
76
- if 'web_data' not in st.session_state:
77
- st.session_state['web_data'] = [] # For storing web scraped data
78
- if 'uploaded_file_path' not in st.session_state:
79
- st.session_state['uploaded_file_path'] = "" # Store the path of saved files
80
 
81
  # Sidebar - Model Selection, Style Parameters, and Cost Display
82
  st.sidebar.title("Model Selection")
@@ -85,210 +77,193 @@ model_name = "gpt2"
85
  # Parameters to adjust the response style and creativity
86
  st.sidebar.title("Response Style Controls")
87
  temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
88
- top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.5, max_value=1.0, value=0.5, step=0.05)
89
  top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
90
  repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
91
- max_length = st.sidebar.slider("Max Length", min_value=50, max_value=1024, value=200, step=10)
92
 
93
- # Load the model and tokenizer
94
  @st.cache_resource
95
  def load_model_and_tokenizer():
96
  model_path = "gpt2" # Path to the local model directory
97
-
98
- # Load model and tokenizer
99
- tokenizer = AutoTokenizer.from_pretrained(model_path) # Use GPT-2 tokenizer from Hugging Face
100
- model = AutoModelForCausalLM.from_pretrained(model_path) # Load the model from the local directory
101
-
102
  return tokenizer, model
103
 
104
  tokenizer, model = load_model_and_tokenizer()
105
 
106
- # Function to generate a response using the model with updated generation configuration
107
  def generate_response(prompt):
108
- # Combine user input with document and web data context
 
 
109
  context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
110
  inputs = tokenizer(context, return_tensors="pt")
111
 
112
- # Define generation configuration
113
  generation_config = {
114
  "max_length": max_length,
115
  "temperature": temperature,
116
  "top_p": top_p,
117
  "top_k": top_k,
118
  "repetition_penalty": repetition_penalty,
119
- "pad_token_id": tokenizer.eos_token_id
 
120
  }
121
 
122
- # Pass attention_mask and generate the output
123
  outputs = model.generate(
124
  inputs.input_ids,
125
  attention_mask=inputs.attention_mask,
126
  **generation_config
127
  )
128
 
129
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
130
- return response
131
 
132
- # Function to reset the session
133
  def reset_session():
134
- st.session_state['generated'] = []
135
- st.session_state['past'] = []
136
- st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
137
- st.session_state['model_name'] = []
138
- st.session_state['total_tokens'] = []
139
- st.session_state['total_cost'] = 0.0
140
- st.session_state['chat_data'] = [] # Reset chat logs
141
- st.session_state['uploaded_docs'] = [] # Reset uploaded docs
142
- st.session_state['web_data'] = [] # Reset web data
143
-
144
- # Reset chat button in sidebar
145
  reset_button = st.sidebar.button("Reset Chat")
146
  if reset_button:
147
  reset_session()
148
 
149
- # Function to save chat logs for later fine-tuning
150
  def save_chat_data(chat_data):
151
- with open('chat_data.json', 'w') as f:
 
152
  json.dump(chat_data, f, indent=4)
153
 
154
- # Function to handle uploaded text or PDF files and convert PDF to txt
155
  def handle_uploaded_file(uploaded_file):
156
  dataset_dir = "./datasets"
157
  dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
158
 
159
  # Check if the file is a PDF
160
  if uploaded_file.type == "application/pdf":
161
- # Read and extract text from the PDF
162
  pdf_reader = PdfReader(uploaded_file)
163
  text = ""
164
  for page in pdf_reader.pages:
165
  text += page.extract_text()
166
-
167
- # Save extracted text as a .txt file
168
- with open(dataset_path, 'w') as f:
 
 
 
169
  f.write(text)
170
  st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
171
  else:
172
- # If it's a text file, save it as is
173
- with open(dataset_path, 'wb') as f:
174
  f.write(uploaded_file.getbuffer())
175
  st.success(f"File saved to {dataset_path}")
176
-
177
- st.session_state['uploaded_file_path'] = str(dataset_path)
178
 
179
- # Add a file uploader for various formats
180
- st.sidebar.title("Upload Documents")
181
- uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
182
-
183
- # Process uploaded file
184
- if uploaded_file is not None:
185
- handle_uploaded_file(uploaded_file)
186
 
187
- # Function to fetch and scrape website content
188
  def handle_web_link(url):
189
- response = requests.get(url)
190
- if response.status_code == 200:
191
- soup = BeautifulSoup(response.content, 'html.parser')
 
 
192
  text = soup.get_text()
193
- st.session_state['web_data'].append(text)
194
  st.success(f"Content from {url} saved successfully!")
195
- else:
196
- st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
197
 
198
- # Add a text box for entering website links
199
  st.sidebar.title("Add Website Links")
200
  web_link = st.sidebar.text_input("Enter Website URL")
201
-
202
- # Process web link
203
  if web_link:
204
  handle_web_link(web_link)
205
 
206
- # Containers for chat history and user input
207
  response_container = st.container()
208
  container = st.container()
209
 
210
  with container:
211
- with st.form(key='user_input_form'):
212
- user_input = st.text_area("You:", key='user_input', height=100)
213
  submit_button = st.form_submit_button("Send")
214
 
215
  if submit_button and user_input:
216
  start_time = time.time()
217
  output = generate_response(user_input)
218
- end_time = time.time()
219
- inference_time = end_time - start_time
220
 
221
- # Append user input and model output to session state
222
- st.session_state['past'].append(user_input)
223
- st.session_state['generated'].append(output)
224
- st.session_state['model_name'].append(model_name)
225
 
226
  # Log chat data for future training
227
- st.session_state['chat_data'].append({
228
- "user_input": user_input,
229
- "model_response": output
230
- })
231
-
232
- # Save chat data to a file (this could be used later for training)
233
- save_chat_data(st.session_state['chat_data'])
234
-
235
- # Calculate tokens and cost
236
- num_tokens = len(tokenizer.encode(user_input)) + len(tokenizer.encode(output))
237
- st.session_state['total_tokens'].append(num_tokens)
238
- cost_per_1000_tokens = 0.0001
239
- cost = cost_per_1000_tokens * (num_tokens / 1000)
240
- st.session_state['total_cost'] += cost
241
-
242
- # Display chat history
243
  with response_container:
244
- for i in range(len(st.session_state['generated'])):
245
- message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
246
- message(st.session_state['generated'][i], key=str(i))
247
-
248
 
249
- # Function to fine-tune the model using uploaded dataset
250
  def fine_tune_model():
251
- # Check if a dataset has been uploaded
252
- uploaded_file_path = st.session_state['uploaded_file_path']
253
  if not uploaded_file_path:
254
  st.warning("Please upload a text or PDF dataset to fine-tune the model.")
255
  return
256
-
257
  # Prepare dataset for fine-tuning (using the uploaded .txt file)
258
- train_dataset = TextDataset(
259
- tokenizer=tokenizer,
260
- file_path=uploaded_file_path, # Ensure this path is a .txt file
261
- block_size=128
262
- )
263
- data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
264
-
265
- # Define training arguments
266
- training_args = TrainingArguments(
267
- output_dir='./gpt2-finetuned',
268
- overwrite_output_dir=True,
269
- num_train_epochs=3,
270
- per_device_train_batch_size=8,
271
- save_steps=10_000,
272
- save_total_limit=2,
273
- logging_dir='./logs',
274
- logging_steps=200,
275
- )
276
-
277
- # Initialize the Trainer
278
- trainer = Trainer(
279
- model=model,
280
- args=training_args,
281
- data_collator=data_collator,
282
- train_dataset=train_dataset
283
- )
284
-
285
- # Fine-tune the model
286
- trainer.train()
 
 
 
 
 
 
287
 
288
- st.success("Model fine-tuning completed successfully.")
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  # Add a button to trigger fine-tuning
291
  st.sidebar.title("Fine-Tune Model")
292
  fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
293
  if fine_tune_button:
294
- fine_tune_model()
 
1
+ import os
2
  import time
3
  import re
4
  import json
5
  import streamlit as st
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
7
  from streamlit_chat import message
 
8
  from pathlib import Path
9
  import torch
10
  from PyPDF2 import PdfReader
 
52
  border-radius: 8px;
53
  }
54
  </style>
55
+ """,
56
+ unsafe_allow_html=True,
57
  )
58
 
59
  # Initialize session state variables
60
+ if "generated" not in st.session_state:
61
+ st.session_state["generated"] = []
62
+ if "past" not in st.session_state:
63
+ st.session_state["past"] = []
64
+ if "messages" not in st.session_state:
65
+ st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
66
+ if "chat_data" not in st.session_state:
67
+ st.session_state["chat_data"] = [] # For storing the chat logs
68
+ if "uploaded_docs" not in st.session_state:
69
+ st.session_state["uploaded_docs"] = [] # For storing uploaded document content
70
+ if "web_data" not in st.session_state:
71
+ st.session_state["web_data"] = [] # For storing web scraped data
 
 
 
 
 
 
 
 
72
 
73
  # Sidebar - Model Selection, Style Parameters, and Cost Display
74
  st.sidebar.title("Model Selection")
 
77
  # Parameters to adjust the response style and creativity
78
  st.sidebar.title("Response Style Controls")
79
  temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
80
+ top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
81
  top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
82
  repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
83
+ max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=800, step=10)
84
 
 
85
  @st.cache_resource
86
  def load_model_and_tokenizer():
87
  model_path = "gpt2" # Path to the local model directory
88
+ tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
89
+ model = AutoModelForCausalLM.from_pretrained(model_path)
 
 
 
90
  return tokenizer, model
91
 
92
  tokenizer, model = load_model_and_tokenizer()
93
 
 
94
  def generate_response(prompt):
95
+ """
96
+ Generate a response using the GPT-2 model, including document and web data context.
97
+ """
98
  context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
99
  inputs = tokenizer(context, return_tensors="pt")
100
 
 
101
  generation_config = {
102
  "max_length": max_length,
103
  "temperature": temperature,
104
  "top_p": top_p,
105
  "top_k": top_k,
106
  "repetition_penalty": repetition_penalty,
107
+ "pad_token_id": tokenizer.eos_token_id,
108
+ "do_sample": True # Always sample tokens
109
  }
110
 
 
111
  outputs = model.generate(
112
  inputs.input_ids,
113
  attention_mask=inputs.attention_mask,
114
  **generation_config
115
  )
116
 
117
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
118
 
119
+ # Reset session
120
  def reset_session():
121
+ """ Reset all session state variables. """
122
+ st.session_state["generated"] = []
123
+ st.session_state["past"] = []
124
+ st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
125
+ st.session_state["chat_data"] = [] # Reset chat logs
126
+ st.session_state["uploaded_docs"] = [] # Reset uploaded docs
127
+ st.session_state["web_data"] = [] # Reset web data
128
+
 
 
 
129
  reset_button = st.sidebar.button("Reset Chat")
130
  if reset_button:
131
  reset_session()
132
 
 
133
  def save_chat_data(chat_data):
134
+ """ Save chat logs for future fine-tuning or reference. """
135
+ with open("chat_data.json", "w") as f:
136
  json.dump(chat_data, f, indent=4)
137
 
 
138
  def handle_uploaded_file(uploaded_file):
139
  dataset_dir = "./datasets"
140
  dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
141
 
142
  # Check if the file is a PDF
143
  if uploaded_file.type == "application/pdf":
 
144
  pdf_reader = PdfReader(uploaded_file)
145
  text = ""
146
  for page in pdf_reader.pages:
147
  text += page.extract_text()
148
+
149
+ if not text:
150
+ st.error("Failed to extract text from the PDF.")
151
+ return None # Return None if text extraction fails
152
+
153
+ with open(dataset_path, "w") as f:
154
  f.write(text)
155
  st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
156
  else:
157
+ with open(dataset_path, "wb") as f:
 
158
  f.write(uploaded_file.getbuffer())
159
  st.success(f"File saved to {dataset_path}")
 
 
160
 
161
+ return str(dataset_path) # Return the path to the saved file
 
 
 
 
 
 
162
 
 
163
  def handle_web_link(url):
164
+ """ Fetch and scrape text content from a website. """
165
+ try:
166
+ response = requests.get(url)
167
+ response.raise_for_status()
168
+ soup = BeautifulSoup(response.content, "html.parser")
169
  text = soup.get_text()
170
+ st.session_state["web_data"].append(text)
171
  st.success(f"Content from {url} saved successfully!")
172
+ except requests.exceptions.RequestException as e:
173
+ st.error(f"Failed to retrieve content: {e}")
174
 
 
175
  st.sidebar.title("Add Website Links")
176
  web_link = st.sidebar.text_input("Enter Website URL")
 
 
177
  if web_link:
178
  handle_web_link(web_link)
179
 
180
+ # Chat interface
181
  response_container = st.container()
182
  container = st.container()
183
 
184
  with container:
185
+ with st.form(key="user_input_form"):
186
+ user_input = st.text_area("You:", key="user_input", height=100)
187
  submit_button = st.form_submit_button("Send")
188
 
189
  if submit_button and user_input:
190
  start_time = time.time()
191
  output = generate_response(user_input)
192
+ inference_time = time.time() - start_time
 
193
 
194
+ st.session_state["past"].append(user_input)
195
+ st.session_state["generated"].append(output)
 
 
196
 
197
  # Log chat data for future training
198
+ st.session_state["chat_data"].append(
199
+ {"user_input": user_input, "model_response": output}
200
+ )
201
+
202
+ save_chat_data(st.session_state["chat_data"])
203
+
 
 
 
 
 
 
 
 
 
 
204
  with response_container:
205
+ for i in range(len(st.session_state["generated"])):
206
+ message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
207
+ message(st.session_state["generated"][i], key=str(i))
 
208
 
 
209
  def fine_tune_model():
210
+ uploaded_file_path = st.session_state.get("uploaded_file_path", "")
 
211
  if not uploaded_file_path:
212
  st.warning("Please upload a text or PDF dataset to fine-tune the model.")
213
  return
214
+
215
  # Prepare dataset for fine-tuning (using the uploaded .txt file)
216
+ try:
217
+ with open(uploaded_file_path, "r") as f:
218
+ text = f.read().strip() # Ensure that the file is not empty
219
+ if len(text) == 0:
220
+ raise ValueError("The dataset is empty.")
221
+ train_dataset = TextDataset(
222
+ tokenizer=tokenizer,
223
+ file_path=uploaded_file_path, # Ensure this path is a .txt file
224
+ block_size=128,
225
+ )
226
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
227
+
228
+ # Define training arguments
229
+ training_args = TrainingArguments(
230
+ output_dir="./gpt2-finetuned",
231
+ overwrite_output_dir=True,
232
+ num_train_epochs=3,
233
+ per_device_train_batch_size=8,
234
+ save_steps=10_000,
235
+ save_total_limit=2,
236
+ logging_dir="./logs",
237
+ logging_steps=200,
238
+ )
239
+
240
+ # Initialize the Trainer
241
+ trainer = Trainer(
242
+ model=model,
243
+ args=training_args,
244
+ data_collator=data_collator,
245
+ train_dataset=train_dataset,
246
+ )
247
+
248
+ # Fine-tune the model
249
+ trainer.train()
250
+ st.success("Model fine-tuning completed successfully.")
251
 
252
+ except Exception as e:
253
+ st.error(f"Error during fine-tuning: {str(e)}")
254
+
255
+ # Sidebar file upload
256
+ st.sidebar.title("Upload Documents")
257
+ uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
258
+
259
+ # Process uploaded file
260
+ if uploaded_file is not None:
261
+ file_path = handle_uploaded_file(uploaded_file)
262
+ if file_path:
263
+ st.session_state["uploaded_file_path"] = file_path
264
 
265
  # Add a button to trigger fine-tuning
266
  st.sidebar.title("Fine-Tune Model")
267
  fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
268
  if fine_tune_button:
269
+ fine_tune_model()