Spaces:
Sleeping
Sleeping
File size: 11,703 Bytes
# Standard library
import json
import os
import re
import time
from pathlib import Path

# Third-party
import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from datasets import load_dataset  # used to load text/CSV fine-tuning data
from PyPDF2 import PdfReader
from streamlit_chat import message
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TextDataset,
    TextDataset as _TextDataset,  # noqa: F401 -- kept: original imported it (unused here)
    Trainer,
    TrainingArguments,
)
# Configure the browser tab title and icon for the app.
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")
# Custom CSS for styling chat messages and buttons.
# The <style> block below themes buttons, text areas, user/assistant chat
# bubbles, and <pre> code blocks.  It is injected as raw HTML, which is why
# unsafe_allow_html=True is required on the st.markdown call.
st.markdown(
"""
<style>
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 12px;
padding: 10px 24px;
}
.stTextArea textarea {
background-color: #f5f5f5;
}
.stDownloadButton>button {
background-color: #4CAF50;
color: white;
}
.stMessageContainer {
border-radius: 15px;
padding: 10px;
margin: 10px 0;
}
.stMessage--user {
background-color: #dfe7f3;
border-left: 6px solid #006699;
}
.stMessage--assistant {
background-color: #f3f3f3;
border-left: 6px solid #4CAF50;
}
pre {
background-color: #f5f5f5;
border-left: 6px solid #dfe7f3;
padding: 10px;
font-size: 14px;
border-radius: 8px;
}
</style>
""",
# Required so the raw <style> HTML above is rendered rather than escaped.
unsafe_allow_html=True,
)
# Seed st.session_state with every key the rest of the app reads, so later
# code can assume they all exist.  Each default is only installed when the
# key is missing, preserving state across Streamlit reruns.
_SESSION_DEFAULTS = {
    "generated": [],            # model replies, in conversation order
    "past": [],                 # user inputs, in conversation order
    "messages": [{"role": "system", "content": "You are a helpful assistant."}],
    "model_name": [],           # model used for each turn
    "total_tokens": [],         # per-turn token counts
    "total_cost": 0.0,          # running cost estimate
    "chat_data": [],            # chat logs (persisted for fine-tuning)
    "uploaded_docs": [],        # text extracted from uploaded documents
    "web_data": [],             # text scraped from web links
    "uploaded_file_path": "",   # path of the most recently saved dataset file
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# ---- Sidebar: model selection and response-style controls ----
sidebar = st.sidebar

sidebar.title("Model Selection")
model_name = "gpt2"

# Sliders that tune creativity and sampling behaviour of generation.
sidebar.title("Response Style Controls")
temperature = sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
max_length = sidebar.slider("Max Length", min_value=100, max_value=4048, value=400, step=10)
# Load the model and tokenizer once per process; Streamlit caches the result
# so reruns of the script reuse the same objects.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 checkpoint and its tokenizer (cached by Streamlit)."""
    checkpoint = "gpt2"  # local path / hub id of the model
    tok = AutoTokenizer.from_pretrained(checkpoint, clean_up_tokenization_spaces=True)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return tok, lm


tokenizer, model = load_model_and_tokenizer()
# Generation uses sampling.  Defined BEFORE generate_response so the flag is
# guaranteed to exist whenever the function runs (the original defined it
# after the function, relying on call order).
do_sample = True

# GPT-2 ships without a pad token; reuse eos so padding is well-defined.
tokenizer.pad_token = tokenizer.eos_token


def generate_response(prompt):
    """Generate a model reply for *prompt*.

    The prompt is prefixed with all uploaded-document and scraped-web text
    held in session state.  Generation is controlled by the sidebar sliders
    (temperature, top_p, top_k, repetition_penalty, max_length).  Returns the
    decoded text; note it includes the context+prompt, since the full output
    sequence is decoded.
    """
    context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
    # Truncate so a large context can never exceed the model's window.
    inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length)
    generation_config = {
        "max_length": max_length,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    # Sampling-only knobs: the original passed None for these when do_sample
    # was False, which transformers rejects — include them only when sampling.
    if do_sample:
        generation_config["temperature"] = temperature
        generation_config["top_p"] = top_p
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            **generation_config,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Restore the conversational session keys to their initial values.
def reset_session():
    """Reset chat history, logs, and cached document/web context."""
    fresh_state = {
        "generated": [],
        "past": [],
        "messages": [{"role": "system", "content": "You are a helpful assistant."}],
        "model_name": [],
        "total_tokens": [],
        "total_cost": 0.0,
        "chat_data": [],       # clears the chat log
        "uploaded_docs": [],   # clears uploaded document text
        "web_data": [],        # clears scraped web text
    }
    for key, value in fresh_state.items():
        st.session_state[key] = value


# Sidebar button that wipes the conversation when clicked.
if st.sidebar.button("Reset Chat"):
    reset_session()
# Persist the chat transcript so it can be reused later for fine-tuning.
def save_chat_data(chat_data):
    """Write *chat_data* (list of {"user_input", "model_response"} dicts) to chat_data.json.

    Opens the file with an explicit UTF-8 encoding (the default is
    locale-dependent and can fail on non-ASCII text) and disables
    ensure_ascii so Arabic/Unicode chat content is stored readably
    instead of as \\uXXXX escapes.
    """
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)
# Save an uploaded TXT/PDF into ./dataset as plain text and remember its path.
def handle_uploaded_file(uploaded_file):
    """Persist *uploaded_file* under ./dataset for later fine-tuning.

    PDFs are converted to plain text; text files are saved verbatim.  The
    saved path is stored in st.session_state["uploaded_file_path"].
    """
    dataset_dir = Path("./dataset")
    # The directory is never created elsewhere — without this the first
    # upload crashes with FileNotFoundError.
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dataset_path = dataset_dir / f"{uploaded_file.name}.txt"
    if uploaded_file.type == "application/pdf":
        pdf_reader = PdfReader(uploaded_file)
        # extract_text() can return None for image-only pages — treat as "".
        text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        # Explicit UTF-8: the platform default encoding may not be able to
        # represent extracted non-ASCII text.
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        # Text files are copied through byte-for-byte.
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")
    st.session_state["uploaded_file_path"] = str(dataset_path)
# ---- Sidebar: document upload (txt / pdf) ----
st.sidebar.title("Upload Documents")
doc_upload = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])

# Convert/save the document as soon as one is provided.
if doc_upload is not None:
    handle_uploaded_file(doc_upload)
# Fetch a web page and add its visible text to the chat context.
def handle_web_link(url):
    """Scrape *url* and append its plain text to st.session_state["web_data"].

    A request timeout stops the app from hanging on a slow server, and
    network errors (bad host, DNS, SSL, timeout) are shown in the UI
    instead of crashing the script.
    """
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        st.error(f"Failed to retrieve content from {url}. Error: {exc}")
        return
    if response.status_code == 200:
        soup = BeautifulSoup(response.content, "html.parser")
        st.session_state["web_data"].append(soup.get_text())
        st.success(f"Content from {url} saved successfully!")
    else:
        st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
# ---- Sidebar: website scraping ----
st.sidebar.title("Add Website Links")

# Scrape immediately once a non-empty URL is entered.
if url_to_scrape := st.sidebar.text_input("Enter Website URL"):
    handle_web_link(url_to_scrape)
# Chat layout: history is rendered above, the input form sits below.
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        # Time the generation call (wall clock).
        t0 = time.time()
        output = generate_response(user_input)
        inference_time = time.time() - t0

        # Record this turn in the session history.
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        st.session_state["model_name"].append(model_name)

        # Log the exchange for possible future fine-tuning.
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )
        # Persist the full transcript to disk after every turn.
        save_chat_data(st.session_state["chat_data"])
# Render the conversation so far: each user message followed by the reply.
with response_container:
    history = zip(st.session_state["past"], st.session_state["generated"])
    for turn, (question, answer) in enumerate(history):
        message(question, is_user=True, key=f"{turn}_user")
        message(answer, key=str(turn))
# Fine-tune the global GPT-2 model on the uploaded dataset.
def fine_tune_model():
    """Fine-tune the loaded GPT-2 model on session_state["uploaded_file_path"].

    Supports plain-text (.txt) and CSV (.csv with a "text" column) datasets.
    A 10% evaluation split is held out because eval_strategy="steps" and
    load_best_model_at_end=True require an eval dataset — without one the
    Trainer raises at the first evaluation step.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", None)
    if not uploaded_file_path:
        st.warning("يرجى تحميل dataset لتدريب النموذج.")
        return

    # Load the raw dataset (text or CSV).
    if uploaded_file_path.endswith('.txt'):
        dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
    elif uploaded_file_path.endswith('.csv'):
        dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
    else:
        # Previously an unsupported extension left `dataset` unbound and the
        # code below crashed with NameError.
        st.error("Unsupported dataset format. Please upload a .txt or .csv file.")
        return

    # Tokenize every example to a fixed length of 512 tokens.
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Hold out an evaluation split so periodic evaluation has data to run on.
    split = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset, eval_dataset = split["train"], split["test"]

    # Causal-LM objective: no masked-language-modeling.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Enable fp16 only when a GPU is available.
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir='./gpt2-finetuned',
        overwrite_output_dir=True,
        num_train_epochs=4,
        per_device_train_batch_size=3,
        per_device_eval_batch_size=3,
        save_steps=500,
        eval_strategy="steps",
        eval_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        # The original asked for 'accuracy', which is never computed without a
        # compute_metrics function; eval loss is always available (lower=better).
        metric_for_best_model='loss',
        greater_is_better=False,
        fp16=use_fp16,
        remove_unused_columns=False,  # keep tokenized columns for the collator
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    trainer.train()
    trainer.save_model()  # persist the fine-tuned weights to output_dir
    st.success("تم إكمال تدريب النموذج بنجاح.")
# ---- Fine-tuning UI: upload a dataset, then kick off training ----
st.title("Fine-tune GPT-2 Model")

dataset_upload = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
if dataset_upload:
    # Save the upload into the working directory and remember its path.
    st.session_state["uploaded_file_path"] = dataset_upload.name
    with open(dataset_upload.name, "wb") as fh:
        fh.write(dataset_upload.getbuffer())
    st.success(f"File {dataset_upload.name} uploaded successfully.")

if st.button("Start Fine-tuning"):
    fine_tune_model()