odaly committed on
Commit 008983b · verified · 1 Parent(s): 8692970

Update app.py

Files changed (1)
  1. app.py +109 -136
app.py CHANGED
@@ -1,11 +1,11 @@
- import os
  import time
  import re
  import json
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments,AutoTokenizer, TextDataset, DataCollatorForLanguageModeling
  from streamlit_chat import message
- from datasets import load_dataset # changed to use the datasets library
  from pathlib import Path
  import torch
  from PyPDF2 import PdfReader
@@ -53,31 +53,30 @@ st.markdown(
      border-radius: 8px;
      }
      </style>
-     """,
-     unsafe_allow_html=True,
  )
  
  # Initialize session state variables
- if "generated" not in st.session_state:
-     st.session_state["generated"] = []
- if "past" not in st.session_state:
-     st.session_state["past"] = []
- if "messages" not in st.session_state:
-     st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
- if "model_name" not in st.session_state:
-     st.session_state["model_name"] = []
- if "total_tokens" not in st.session_state:
-     st.session_state["total_tokens"] = []
- if "total_cost" not in st.session_state:
-     st.session_state["total_cost"] = 0.0
- if "chat_data" not in st.session_state:
-     st.session_state["chat_data"] = [] # For storing the chat logs
- if "uploaded_docs" not in st.session_state:
-     st.session_state["uploaded_docs"] = [] # For storing uploaded document content
- if "web_data" not in st.session_state:
-     st.session_state["web_data"] = [] # For storing web scraped data
- if "uploaded_file_path" not in st.session_state:
-     st.session_state["uploaded_file_path"] = "" # Store the path of saved files
  
  # Sidebar - Model Selection, Style Parameters, and Cost Display
  st.sidebar.title("Model Selection")
@@ -86,38 +85,41 @@ model_name = "gpt2"
86
  # Parameters to adjust the response style and creativity
87
  st.sidebar.title("Response Style Controls")
88
  temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
89
- top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
90
  top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
91
  repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
92
- max_length = st.sidebar.slider("Max Length", min_value=100, max_value=4048, value=400, step=10)
93
 
94
  # Load the model and tokenizer
95
  @st.cache_resource
96
  def load_model_and_tokenizer():
97
- model_path = "gpt2" # المسار المحلي للنموذج
98
- tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
99
- model = AutoModelForCausalLM.from_pretrained(model_path)
 
 
 
100
  return tokenizer, model
101
 
102
  tokenizer, model = load_model_and_tokenizer()
103
- # Function to generate a response using the model with updated generation configuration
104
- # إعداد متغيرات TrainingArguments مع تحسينات
105
- tokenizer.pad_token = tokenizer.eos_token # لضمان أن المفكرة تستخدم رمز eos كـ pad token
106
 
 
107
  def generate_response(prompt):
 
108
  context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
109
  inputs = tokenizer(context, return_tensors="pt")
110
 
 
111
  generation_config = {
112
  "max_length": max_length,
113
- "temperature": temperature if do_sample else None,
114
- "top_p": top_p if do_sample else None,
115
  "top_k": top_k,
116
  "repetition_penalty": repetition_penalty,
117
- "pad_token_id": tokenizer.eos_token_id,
118
- "do_sample": do_sample
119
  }
120
 
 
121
  outputs = model.generate(
122
  inputs.input_ids,
123
  attention_mask=inputs.attention_mask,
@@ -126,36 +128,32 @@ def generate_response(prompt):
  
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return response
- # Set do_sample to True
- do_sample = True
  # Function to reset the session
  def reset_session():
-     st.session_state["generated"] = []
-     st.session_state["past"] = []
-     st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
-     st.session_state["model_name"] = []
-     st.session_state["total_tokens"] = []
-     st.session_state["total_cost"] = 0.0
-     st.session_state["chat_data"] = [] # Reset chat logs
-     st.session_state["uploaded_docs"] = [] # Reset uploaded docs
-     st.session_state["web_data"] = [] # Reset web data
- 
  
  # Reset chat button in sidebar
  reset_button = st.sidebar.button("Reset Chat")
  if reset_button:
      reset_session()
  
- 
  # Function to save chat logs for later fine-tuning
  def save_chat_data(chat_data):
-     with open("chat_data.json", "w") as f:
          json.dump(chat_data, f, indent=4)
  
- 
  # Function to handle uploaded text or PDF files and convert PDF to txt
  def handle_uploaded_file(uploaded_file):
-     dataset_dir = "./datasets"
      dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
  
      # Check if the file is a PDF
@@ -165,19 +163,18 @@ def handle_uploaded_file(uploaded_file):
          text = ""
          for page in pdf_reader.pages:
              text += page.extract_text()
- 
          # Save extracted text as a .txt file
-         with open(dataset_path, "w") as f:
              f.write(text)
          st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
      else:
          # If it's a text file, save it as is
-         with open(dataset_path, "wb") as f:
              f.write(uploaded_file.getbuffer())
          st.success(f"File saved to {dataset_path}")
- 
-     st.session_state["uploaded_file_path"] = str(dataset_path)
- 
  
  # Add a file uploader for various formats
  st.sidebar.title("Upload Documents")
@@ -187,19 +184,17 @@ uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
  if uploaded_file is not None:
      handle_uploaded_file(uploaded_file)
  
- 
  # Function to fetch and scrape website content
  def handle_web_link(url):
      response = requests.get(url)
      if response.status_code == 200:
-         soup = BeautifulSoup(response.content, "html.parser")
          text = soup.get_text()
-         st.session_state["web_data"].append(text)
          st.success(f"Content from {url} saved successfully!")
      else:
          st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
  
- 
  # Add a text box for entering website links
  st.sidebar.title("Add Website Links")
  web_link = st.sidebar.text_input("Enter Website URL")
@@ -208,14 +203,13 @@ web_link = st.sidebar.text_input("Enter Website URL")
  if web_link:
      handle_web_link(web_link)
  
- 
  # Containers for chat history and user input
  response_container = st.container()
  container = st.container()
  
  with container:
-     with st.form(key="user_input_form"):
-         user_input = st.text_area("You:", key="user_input", height=100)
          submit_button = st.form_submit_button("Send")
  
      if submit_button and user_input:
@@ -225,97 +219,76 @@ with container:
          inference_time = end_time - start_time
  
          # Append user input and model output to session state
-         st.session_state["past"].append(user_input)
-         st.session_state["generated"].append(output)
-         st.session_state["model_name"].append(model_name)
  
          # Log chat data for future training
-         st.session_state["chat_data"].append(
-             {"user_input": user_input, "model_response": output}
-         )
  
          # Save chat data to a file (this could be used later for training)
-         save_chat_data(st.session_state["chat_data"])
  
          # Calculate tokens and cost
- 
  
  # Display chat history
  with response_container:
-     for i in range(len(st.session_state["generated"])):
-         message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
-         message(st.session_state["generated"][i], key=str(i))
- 
  
  # Function to fine-tune the model using uploaded dataset
  def fine_tune_model():
-     uploaded_file_path = st.session_state.get("uploaded_file_path", None)
      if not uploaded_file_path:
-         st.warning("Please upload a dataset to fine-tune the model.")
          return
- 
-     # Load the text or CSV data
-     if uploaded_file_path.endswith('.txt'):
-         dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
-     elif uploaded_file_path.endswith('.csv'):
-         dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
- 
-     # Process the data: convert the texts to tokens (tokenization)
-     def tokenize_function(examples):
-         return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)
- 
-     tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
- 
-     # Set up the collator without masked language modeling
      data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
- 
-     # Check whether the system has a GPU
-     use_fp16 = torch.cuda.is_available() # enable fp16 only if a GPU is available
- 
-     # Set up the TrainingArguments
      training_args = TrainingArguments(
-         output_dir='./gpt2-finetuned',
-         overwrite_output_dir=True,
-         num_train_epochs=4,
-         per_device_train_batch_size=3,
-         per_device_eval_batch_size=3,
-         save_steps=500,
-         eval_strategy="steps",
-         eval_steps=500,
-         learning_rate=2e-5,
-         weight_decay=0.01,
-         logging_dir='./logs',
-         logging_steps=100,
-         save_total_limit=3,
-         load_best_model_at_end=True,
-         metric_for_best_model='accuracy',
-         greater_is_better=True,
-         fp16=use_fp16, # enable fp16 only if a GPU is available
-         remove_unused_columns=False, # disable this option to work around the column mismatch issue
      )
- 
-     # Initialize the Trainer
      trainer = Trainer(
          model=model,
          args=training_args,
          data_collator=data_collator,
-         train_dataset=tokenized_dataset,
      )
- 
-     # Start training
      trainer.train()
  
-     st.success("Model training completed successfully.")
- 
- # Streamlit UI for uploading a dataset and starting training
- st.title("Fine-tune GPT-2 Model")
- 
- uploaded_file = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
- if uploaded_file:
-     st.session_state["uploaded_file_path"] = uploaded_file.name
-     with open(uploaded_file.name, "wb") as f:
-         f.write(uploaded_file.getbuffer())
-     st.success(f"File {uploaded_file.name} uploaded successfully.")
- 
- if st.button("Start Fine-tuning"):
-     fine_tune_model()
 
+ import os
  import time
  import re
  import json
  import streamlit as st
+ from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments,TextDataset, DataCollatorForLanguageModeling
  from streamlit_chat import message
+ 
  from pathlib import Path
  import torch
  from PyPDF2 import PdfReader
 
      border-radius: 8px;
      }
      </style>
+     """, unsafe_allow_html=True
  )
  
  # Initialize session state variables
+ if 'generated' not in st.session_state:
+     st.session_state['generated'] = []
+ if 'past' not in st.session_state:
+     st.session_state['past'] = []
+ if 'messages' not in st.session_state:
+     st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
+ if 'model_name' not in st.session_state:
+     st.session_state['model_name'] = []
+ if 'total_tokens' not in st.session_state:
+     st.session_state['total_tokens'] = []
+ if 'total_cost' not in st.session_state:
+     st.session_state['total_cost'] = 0.0
+ if 'chat_data' not in st.session_state:
+     st.session_state['chat_data'] = [] # For storing the chat logs
+ if 'uploaded_docs' not in st.session_state:
+     st.session_state['uploaded_docs'] = [] # For storing uploaded document content
+ if 'web_data' not in st.session_state:
+     st.session_state['web_data'] = [] # For storing web scraped data
+ if 'uploaded_file_path' not in st.session_state:
+     st.session_state['uploaded_file_path'] = "" # Store the path of saved files
  
  # Sidebar - Model Selection, Style Parameters, and Cost Display
  st.sidebar.title("Model Selection")
 
  # Parameters to adjust the response style and creativity
  st.sidebar.title("Response Style Controls")
  temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
+ top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.5, max_value=1.0, value=0.5, step=0.05)
  top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
  repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
+ max_length = st.sidebar.slider("Max Length", min_value=50, max_value=1024, value=200, step=10)
  
  # Load the model and tokenizer
  @st.cache_resource
  def load_model_and_tokenizer():
+     model_path = "C:/Users/MC/Ollama_UI/models--gpt2" # Path to the local model directory
+ 
+     # Load model and tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(model_path) # Use GPT-2 tokenizer from Hugging Face
+     model = AutoModelForCausalLM.from_pretrained(model_path) # Load the model from the local directory
+ 
      return tokenizer, model
  
  tokenizer, model = load_model_and_tokenizer()
  
+ # Function to generate a response using the model with updated generation configuration
  def generate_response(prompt):
+     # Combine user input with document and web data context
      context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
      inputs = tokenizer(context, return_tensors="pt")
  
+     # Define generation configuration
      generation_config = {
          "max_length": max_length,
+         "temperature": temperature,
+         "top_p": top_p,
          "top_k": top_k,
          "repetition_penalty": repetition_penalty,
+         "pad_token_id": tokenizer.eos_token_id
      }
  
+     # Pass attention_mask and generate the output
      outputs = model.generate(
          inputs.input_ids,
          attention_mask=inputs.attention_mask,
 
  
      response = tokenizer.decode(outputs[0], skip_special_tokens=True)
      return response
+ 
  # Function to reset the session
  def reset_session():
+     st.session_state['generated'] = []
+     st.session_state['past'] = []
+     st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
+     st.session_state['model_name'] = []
+     st.session_state['total_tokens'] = []
+     st.session_state['total_cost'] = 0.0
+     st.session_state['chat_data'] = [] # Reset chat logs
+     st.session_state['uploaded_docs'] = [] # Reset uploaded docs
+     st.session_state['web_data'] = [] # Reset web data
  
  # Reset chat button in sidebar
  reset_button = st.sidebar.button("Reset Chat")
  if reset_button:
      reset_session()
  
  # Function to save chat logs for later fine-tuning
  def save_chat_data(chat_data):
+     with open('chat_data.json', 'w') as f:
          json.dump(chat_data, f, indent=4)
  
  # Function to handle uploaded text or PDF files and convert PDF to txt
  def handle_uploaded_file(uploaded_file):
+     dataset_dir = "C:/Users/MC/Ollama_UI/datasets"
      dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
  
      # Check if the file is a PDF
 
          text = ""
          for page in pdf_reader.pages:
              text += page.extract_text()
+ 
          # Save extracted text as a .txt file
+         with open(dataset_path, 'w') as f:
              f.write(text)
          st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
      else:
          # If it's a text file, save it as is
+         with open(dataset_path, 'wb') as f:
              f.write(uploaded_file.getbuffer())
          st.success(f"File saved to {dataset_path}")
+ 
+     st.session_state['uploaded_file_path'] = str(dataset_path)
  
  # Add a file uploader for various formats
  st.sidebar.title("Upload Documents")
 
  if uploaded_file is not None:
      handle_uploaded_file(uploaded_file)
  
  # Function to fetch and scrape website content
  def handle_web_link(url):
      response = requests.get(url)
      if response.status_code == 200:
+         soup = BeautifulSoup(response.content, 'html.parser')
          text = soup.get_text()
+         st.session_state['web_data'].append(text)
          st.success(f"Content from {url} saved successfully!")
      else:
          st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
  
  # Add a text box for entering website links
  st.sidebar.title("Add Website Links")
  web_link = st.sidebar.text_input("Enter Website URL")
 
  if web_link:
      handle_web_link(web_link)
  
  # Containers for chat history and user input
  response_container = st.container()
  container = st.container()
  
  with container:
+     with st.form(key='user_input_form'):
+         user_input = st.text_area("You:", key='user_input', height=100)
          submit_button = st.form_submit_button("Send")
  
      if submit_button and user_input:
 
          inference_time = end_time - start_time
  
          # Append user input and model output to session state
+         st.session_state['past'].append(user_input)
+         st.session_state['generated'].append(output)
+         st.session_state['model_name'].append(model_name)
  
          # Log chat data for future training
+         st.session_state['chat_data'].append({
+             "user_input": user_input,
+             "model_response": output
+         })
  
          # Save chat data to a file (this could be used later for training)
+         save_chat_data(st.session_state['chat_data'])
  
          # Calculate tokens and cost
+         num_tokens = len(tokenizer.encode(user_input)) + len(tokenizer.encode(output))
+         st.session_state['total_tokens'].append(num_tokens)
+         cost_per_1000_tokens = 0.0001
+         cost = cost_per_1000_tokens * (num_tokens / 1000)
+         st.session_state['total_cost'] += cost
  
  # Display chat history
  with response_container:
+     for i in range(len(st.session_state['generated'])):
+         message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
+         message(st.session_state['generated'][i], key=str(i))
+ 
  
  # Function to fine-tune the model using uploaded dataset
  def fine_tune_model():
+     # Check if a dataset has been uploaded
+     uploaded_file_path = st.session_state['uploaded_file_path']
      if not uploaded_file_path:
+         st.warning("Please upload a text or PDF dataset to fine-tune the model.")
          return
+ 
+     # Prepare dataset for fine-tuning (using the uploaded .txt file)
+     train_dataset = TextDataset(
+         tokenizer=tokenizer,
+         file_path=uploaded_file_path, # Ensure this path is a .txt file
+         block_size=128
+     )
      data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+ 
+     # Define training arguments
      training_args = TrainingArguments(
+         output_dir='./gpt2-finetuned',
+         overwrite_output_dir=True,
+         num_train_epochs=3,
+         per_device_train_batch_size=8,
+         save_steps=10_000,
+         save_total_limit=2,
+         logging_dir='./logs',
+         logging_steps=200,
      )
+ 
+     # Initialize the Trainer
      trainer = Trainer(
          model=model,
          args=training_args,
          data_collator=data_collator,
+         train_dataset=train_dataset
      )
+ 
+     # Fine-tune the model
      trainer.train()
+ 
+     st.success("Model fine-tuning completed successfully.")
  
+ # Add a button to trigger fine-tuning
+ st.sidebar.title("Fine-Tune Model")
+ fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
+ if fine_tune_button:
+     fine_tune_model()
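
One caveat worth flagging on the new generate_response: the rewritten generation_config no longer sets do_sample, and Hugging Face generate() defaults to greedy decoding when sampling is disabled, so the Temperature, Top-p and Top-k sliders will not actually change the output (recent transformers versions only warn that the sampling values are ignored). Below is a minimal sketch, assuming sampling is the intent and that the same slider variables, tokenizer and inputs from the new code are in scope; it is not part of this commit.

outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    do_sample=True,  # assumption: enable sampling so the sidebar sliders take effect
    max_length=max_length,
    temperature=temperature,
    top_p=top_p,
    top_k=top_k,
    repetition_penalty=repetition_penalty,
    pad_token_id=tokenizer.eos_token_id,
)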