odaly committed on
Commit
3dfe622
·
verified ·
1 Parent(s): 4472eb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -126
app.py CHANGED
@@ -1,17 +1,21 @@
1
  import os
2
  import time
3
  import re
4
- import requests
5
  import json
6
- from bs4 import BeautifulSoup
7
  import streamlit as st
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, LlamaConfig
9
  from streamlit_chat import message
 
 
 
 
 
 
10
 
11
  # Set page title and icon
12
- st.set_page_config(page_title="LLaMA Chatbot", page_icon=":robot_face:")
13
 
14
- # Custom CSS for styling
15
  st.markdown(
16
  """
17
  <style>
@@ -23,146 +27,195 @@ st.markdown(
23
  }
24
  .stTextArea textarea {
25
  background-color: #f5f5f5;
26
- color: red;
27
  }
28
  .stDownloadButton>button {
29
  background-color: #4CAF50;
30
- color: black;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
32
  </style>
33
- """, unsafe_allow_html=True
 
34
  )
35
 
36
- # Load Hugging Face API token
37
-
38
  # Initialize session state variables
39
- if 'generated' not in st.session_state:
40
- st.session_state['generated'] = []
41
- if 'past' not in st.session_state:
42
- st.session_state['past'] = []
43
- if 'messages' not in st.session_state:
44
- st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
45
- if 'model_name' not in st.session_state:
46
- st.session_state['model_name'] = []
47
- if 'total_tokens' not in st.session_state:
48
- st.session_state['total_tokens'] = []
49
- if 'total_cost' not in st.session_state:
50
- st.session_state['total_cost'] = 0.0
51
- if 'chat_data' not in st.session_state:
52
- st.session_state['chat_data'] = [] # For storing the chat logs
 
 
 
 
 
 
53
 
54
  # Sidebar - Model Selection, Style Parameters, and Cost Display
55
  st.sidebar.title("Model Selection")
56
- model_name = st.sidebar.selectbox("Choose a model:", ["gpt2", "gpt-neo-125M", "distilgpt2", "LLaMA"])
57
 
58
  # Parameters to adjust the response style and creativity
59
  st.sidebar.title("Response Style Controls")
60
- temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.5, step=0.1)
61
- top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.5, max_value=1.0, value=0.5, step=0.05)
62
  top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
63
  repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
64
- max_length = st.sidebar.slider("Max Length", min_value=50, max_value=1024, value=500, step=10)
65
 
66
- # Function to load the model and tokenizer
67
  @st.cache_resource
68
- def load_model_and_tokenizer(model_name):
69
- if "LLaMA" in model_name:
70
- tokenizer = LlamaTokenizer.from_pretrained(model_name)
71
- config = LlamaConfig.from_pretrained(model_name)
72
- model = LlamaForCausalLM.from_pretrained(model_name, config=config)
73
- else:
74
- tokenizer = AutoTokenizer.from_pretrained(model_name)
75
- model = AutoModelForCausalLM.from_pretrained(model_name)
76
-
77
  return tokenizer, model
78
 
79
- tokenizer, model = load_model_and_tokenizer(model_name)
 
 
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Function to reset the session
82
  def reset_session():
83
- st.session_state['generated'] = []
84
- st.session_state['past'] = []
85
- st.session_state['messages'] = [{"role": "system", "content": "You are a helpful assistant."}]
86
- st.session_state['model_name'] = []
87
- st.session_state['total_tokens'] = []
88
- st.session_state['total_cost'] = 0.0
89
- st.session_state['chat_data'] = [] # Reset chat logs
 
 
 
90
 
91
  # Reset chat button in sidebar
92
  reset_button = st.sidebar.button("Reset Chat")
93
  if reset_button:
94
  reset_session()
95
 
96
- # Function to fetch and parse a webpage for specific tags
97
- def fetch_website_content(url):
98
- try:
99
- response = requests.get(url)
100
- if response.status_code == 200:
101
- soup = BeautifulSoup(response.text, 'html.parser')
102
- headings = [h.get_text() for h in soup.find_all(['h1', 'h2', 'h3'])]
103
- paragraphs = [p.get_text() for p in soup.find_all('p')]
104
- articles = [article.get_text() for article in soup.find_all('article')]
105
-
106
- content = {
107
- "headings": headings,
108
- "paragraphs": paragraphs,
109
- "articles": articles
110
- }
111
- return content
112
- else:
113
- return {"error": f"Failed to retrieve content, status code: {response.status_code}"}
114
- except Exception as e:
115
- return {"error": f"An error occurred: {str(e)}"}
116
-
117
- # Function to check if the input contains a URL
118
- def extract_url_from_text(text):
119
- url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
120
- urls = re.findall(url_pattern, text)
121
- return urls
122
-
123
- # Function to generate a response using the model with adjustable parameters
124
- def generate_response(prompt):
125
- urls = extract_url_from_text(prompt)
126
-
127
- if urls:
128
- # If a URL is detected, crawl the webpage and extract content
129
- url_content = fetch_website_content(urls[0]) # Crawl only the first URL for simplicity
130
- if 'error' in url_content:
131
- return url_content['error']
132
- else:
133
- return f"Headings: {url_content['headings']}\n\nParagraphs: {url_content['paragraphs']}\n\nArticles: {url_content['articles']}"
134
- else:
135
- # If no URL, proceed with generating a response from the model
136
- inputs = tokenizer(prompt, return_tensors="pt")
137
-
138
- # Pass attention_mask and set pad_token_id
139
- outputs = model.generate(
140
- inputs.input_ids,
141
- attention_mask=inputs.attention_mask,
142
- max_length=max_length,
143
- do_sample=True,
144
- temperature=temperature,
145
- top_p=top_p,
146
- top_k=top_k,
147
- repetition_penalty=repetition_penalty,
148
- pad_token_id=tokenizer.eos_token_id # Set pad_token_id
149
- )
150
-
151
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
152
- return response
153
 
154
  # Function to save chat logs for later fine-tuning
155
  def save_chat_data(chat_data):
156
- with open('chat_data.json', 'w') as f:
157
  json.dump(chat_data, f, indent=4)
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  # Containers for chat history and user input
160
  response_container = st.container()
161
  container = st.container()
162
 
163
  with container:
164
- with st.form(key='user_input_form'):
165
- user_input = st.text_area("You:", key='user_input', height=100)
166
  submit_button = st.form_submit_button("Send")
167
 
168
  if submit_button and user_input:
@@ -172,32 +225,97 @@ with container:
172
  inference_time = end_time - start_time
173
 
174
  # Append user input and model output to session state
175
- st.session_state['past'].append(user_input)
176
- st.session_state['generated'].append(output)
177
- st.session_state['model_name'].append(model_name)
178
 
179
  # Log chat data for future training
180
- st.session_state['chat_data'].append({
181
- "user_input": user_input,
182
- "model_response": output
183
- })
184
 
185
  # Save chat data to a file (this could be used later for training)
186
- save_chat_data(st.session_state['chat_data'])
187
 
188
  # Calculate tokens and cost
189
- num_tokens = len(tokenizer.encode(user_input)) + len(tokenizer.encode(output))
190
- st.session_state['total_tokens'].append(num_tokens)
191
- cost_per_1000_tokens = 0.0001
192
- cost = cost_per_1000_tokens * (num_tokens / 1000)
193
- st.session_state['total_cost'] += cost
194
 
195
  # Display chat history
196
  with response_container:
197
- for i in range(len(st.session_state['generated'])):
198
- message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
199
- message(st.session_state['generated'][i], key=str(i))
200
- st.write(f"Model: {st.session_state['model_name'][i]}")
201
- st.write(f"Tokens: {st.session_state['total_tokens'][i]}")
202
- st.write(f"Inference Time: {inference_time:.4f} seconds")
203
- st.write(f"Cost: ${cost:.5f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Standard library
import os
import time
import re
import json
from pathlib import Path

# Third-party
import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from datasets import load_dataset  # Hugging Face datasets, used for fine-tuning
from streamlit_chat import message
# NOTE: the original line imported AutoTokenizer twice; the duplicate is removed.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    TextDataset,  # not referenced below; kept because the file imported it
    DataCollatorForLanguageModeling,
)
14
 
15
  # Set page title and icon
16
+ st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")
17
 
18
+ # Custom CSS for styling chat messages and buttons
19
  st.markdown(
20
  """
21
  <style>
 
27
  }
28
  .stTextArea textarea {
29
  background-color: #f5f5f5;
 
30
  }
31
  .stDownloadButton>button {
32
  background-color: #4CAF50;
33
+ color: white;
34
+ }
35
+ .stMessageContainer {
36
+ border-radius: 15px;
37
+ padding: 10px;
38
+ margin: 10px 0;
39
+ }
40
+ .stMessage--user {
41
+ background-color: #dfe7f3;
42
+ border-left: 6px solid #006699;
43
+ }
44
+ .stMessage--assistant {
45
+ background-color: #f3f3f3;
46
+ border-left: 6px solid #4CAF50;
47
+ }
48
+ pre {
49
+ background-color: #f5f5f5;
50
+ border-left: 6px solid #dfe7f3;
51
+ padding: 10px;
52
+ font-size: 14px;
53
+ border-radius: 8px;
54
  }
55
  </style>
56
+ """,
57
+ unsafe_allow_html=True,
58
  )
59
 
 
 
60
# Initialize session state variables (only entries that are not set yet).
_session_defaults = {
    "generated": [],                 # model responses shown in the chat
    "past": [],                      # user inputs shown in the chat
    "messages": [{"role": "system", "content": "You are a helpful assistant."}],
    "model_name": [],                # model name recorded per exchange
    "total_tokens": [],              # token counts per exchange
    "total_cost": 0.0,               # accumulated cost estimate
    "chat_data": [],                 # chat logs kept for later fine-tuning
    "uploaded_docs": [],             # text extracted from uploaded documents
    "web_data": [],                  # text scraped from web links
    "uploaded_file_path": "",        # path of the most recently saved file
}
for _key, _default in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
81
 
82
# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
model_name = "gpt2"  # fixed model; the picker was removed from the UI

# Parameters to adjust the response style and creativity
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider(
    "Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1
)
top_p = st.sidebar.slider(
    "Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05
)
top_k = st.sidebar.slider(
    "Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1
)
repetition_penalty = st.sidebar.slider(
    "Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1
)
max_length = st.sidebar.slider(
    "Max Length", min_value=100, max_value=4048, value=400, step=10
)
93
 
94
# Load the model and tokenizer
@st.cache_resource
def load_model_and_tokenizer(
    model_path="C:/Users/MC/Ollama_UI/gpt2-finetuned/checkpoint-416",
):
    """Load a causal-LM checkpoint and its tokenizer, cached for the session.

    Args:
        model_path: Local directory (or hub id) of the model to load.
            Defaults to the author's local fine-tuned GPT-2 checkpoint,
            preserving the original behavior.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    # NOTE(review): the default is a hard-coded Windows path that only exists
    # on the author's machine — consider an environment variable instead.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, clean_up_tokenization_spaces=True
    )
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model
101
 
102
# Load the shared tokenizer/model pair once (cached by st.cache_resource).
tokenizer, model = load_model_and_tokenizer()

# GPT-2 defines no dedicated pad token, so reuse the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token
106
 
107
def generate_response(prompt, do_sample=True):
    """Generate a model reply for *prompt*, conditioned on session context.

    The prompt is prefixed with all uploaded-document text and scraped web
    text held in session state before being fed to the model.

    Args:
        prompt: The user's chat input.
        do_sample: Whether to sample (temperature/top-p apply) or decode
            greedily. Defaults to True, matching the original module-level
            flag — the original read a global ``do_sample`` that was only
            defined *after* this function, which was fragile.

    Returns:
        The decoded model output (includes the prompt context, as before).
    """
    context = (
        " ".join(st.session_state["uploaded_docs"])
        + " "
        + " ".join(st.session_state["web_data"])
        + "\n"
        + prompt
    )
    inputs = tokenizer(context, return_tensors="pt")

    generation_config = {
        "max_length": max_length,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    # temperature/top_p only matter when sampling; omit them otherwise
    # instead of passing None (which triggers transformers warnings).
    if do_sample:
        generation_config["temperature"] = temperature
        generation_config["top_p"] = top_p

    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config,
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
129
# Sampling is enabled globally for chat generation.
do_sample = True
131
# Function to reset the session
def reset_session():
    """Restore every chat-related session-state entry to its initial value."""
    initial_state = {
        "generated": [],
        "past": [],
        "messages": [{"role": "system", "content": "You are a helpful assistant."}],
        "model_name": [],
        "total_tokens": [],
        "total_cost": 0.0,
        "chat_data": [],
        "uploaded_docs": [],
        "web_data": [],
    }
    for state_key, fresh_value in initial_state.items():
        st.session_state[state_key] = fresh_value


# Reset chat button in sidebar
reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
# Function to save chat logs for later fine-tuning
def save_chat_data(chat_data):
    """Persist the chat log to ``chat_data.json`` for later fine-tuning.

    Args:
        chat_data: List of ``{"user_input": ..., "model_response": ...}``
            dicts accumulated in session state.
    """
    # Write UTF-8 explicitly (Windows' default codec can't encode all chat
    # text) and keep non-ASCII characters readable instead of \u-escaped.
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)
154
 
155
+
156
# Function to handle uploaded text or PDF files and convert PDF to txt
def handle_uploaded_file(uploaded_file):
    """Save an uploaded TXT/PDF as a ``.txt`` dataset and remember its path.

    PDF uploads are converted to plain text first; text files are stored
    verbatim. The saved path is stored in
    ``st.session_state["uploaded_file_path"]`` for the fine-tuning step.
    """
    # NOTE(review): hard-coded Windows directory — only valid on the
    # author's machine.
    dataset_dir = Path("C:/Users/MC/Ollama_UI/datasets")
    # Create the directory up front so the open() calls below cannot fail
    # with FileNotFoundError on a fresh machine.
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dataset_path = dataset_dir / f"{uploaded_file.name}.txt"

    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        # Read and extract text from the PDF.
        pdf_reader = PdfReader(uploaded_file)
        # extract_text() can return None for pages with no extractable text;
        # the original `text += page.extract_text()` raised TypeError then.
        text = "".join(page.extract_text() or "" for page in pdf_reader.pages)

        # Save extracted text as a .txt file; UTF-8 so non-ASCII content
        # (e.g. Arabic) survives regardless of the OS default codec.
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        # If it's a text file, save its raw bytes as-is.
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")

    st.session_state["uploaded_file_path"] = str(dataset_path)
180
+
181
+
182
# Sidebar widget for uploading documents in the supported formats.
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])

# Hand any new upload straight to the file handler.
if uploaded_file is not None:
    handle_uploaded_file(uploaded_file)
189
+
190
+
191
# Function to fetch and scrape website content
def handle_web_link(url):
    """Download *url*, extract its visible text, and stash it in session state.

    Network failures and non-200 responses are reported via the UI instead
    of crashing the Streamlit script (the original had no exception handling
    and no timeout, so a dead host hung or killed the app).
    """
    try:
        # A timeout keeps the script from hanging indefinitely on a dead host.
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        st.error(f"Failed to retrieve content from {url}. Error: {exc}")
        return

    if response.status_code == 200:
        soup = BeautifulSoup(response.content, "html.parser")
        text = soup.get_text()
        st.session_state["web_data"].append(text)
        st.success(f"Content from {url} saved successfully!")
    else:
        st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")


# Add a text box for entering website links
st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")

# Process web link
if web_link:
    handle_web_link(web_link)
210
+
211
+
212
  # Containers for chat history and user input
213
  response_container = st.container()
214
  container = st.container()
215
 
216
  with container:
217
+ with st.form(key="user_input_form"):
218
+ user_input = st.text_area("You:", key="user_input", height=100)
219
  submit_button = st.form_submit_button("Send")
220
 
221
  if submit_button and user_input:
 
225
  inference_time = end_time - start_time
226
 
227
  # Append user input and model output to session state
228
+ st.session_state["past"].append(user_input)
229
+ st.session_state["generated"].append(output)
230
+ st.session_state["model_name"].append(model_name)
231
 
232
  # Log chat data for future training
233
+ st.session_state["chat_data"].append(
234
+ {"user_input": user_input, "model_response": output}
235
+ )
 
236
 
237
  # Save chat data to a file (this could be used later for training)
238
+ save_chat_data(st.session_state["chat_data"])
239
 
240
  # Calculate tokens and cost
241
+
 
 
 
 
242
 
243
# Display chat history: user message and assistant reply per exchange.
with response_container:
    for idx, assistant_reply in enumerate(st.session_state["generated"]):
        message(st.session_state["past"][idx], is_user=True, key=f"{idx}_user")
        message(assistant_reply, key=str(idx))
248
+
249
+
250
# Function to fine-tune the model using the uploaded dataset
def fine_tune_model():
    """Fine-tune the loaded GPT-2 model on the most recently uploaded dataset.

    Reads the dataset path from ``st.session_state["uploaded_file_path"]``,
    tokenizes it, and runs a ``Trainer`` loop. Progress and errors are
    reported through the Streamlit UI.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", None)
    if not uploaded_file_path:
        st.warning("يرجى تحميل dataset لتدريب النموذج.")
        return

    # Load the raw text or CSV data.
    if uploaded_file_path.endswith('.txt'):
        dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
    elif uploaded_file_path.endswith('.csv'):
        dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
    else:
        # Previously an unsupported extension left `dataset` unbound and the
        # code below crashed with NameError.
        st.error("Unsupported dataset format; please upload a .txt or .csv file.")
        return

    # Tokenization: map raw text into fixed-length token sequences.
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Causal LM objective: no masked-language-modeling.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Enable fp16 only when a GPU is available.
    use_fp16 = torch.cuda.is_available()

    # No eval_dataset is supplied to the Trainer, so the evaluation-dependent
    # options used originally (eval_strategy="steps", load_best_model_at_end,
    # metric_for_best_model) are dropped — with them set, Trainer raises
    # because there is nothing to evaluate on.
    training_args = TrainingArguments(
        output_dir='./gpt2-finetuned',
        overwrite_output_dir=True,
        num_train_epochs=4,
        per_device_train_batch_size=3,
        save_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        save_total_limit=3,
        fp16=use_fp16,  # fp16 only when CUDA is present
        remove_unused_columns=False,  # keep all tokenized columns for the model
    )

    # Build the Trainer and run training.
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset,
    )

    trainer.train()

    st.success("تم إكمال تدريب النموذج بنجاح.")
309
+
310
# Streamlit UI for uploading a fine-tuning dataset and starting training.
st.title("Fine-tune GPT-2 Model")

uploaded_file = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
if uploaded_file:
    # Remember the file name and persist its bytes to the working directory.
    st.session_state["uploaded_file_path"] = uploaded_file.name
    Path(uploaded_file.name).write_bytes(uploaded_file.getbuffer())
    st.success(f"File {uploaded_file.name} uploaded successfully.")

if st.button("Start Fine-tuning"):
    fine_tune_model()