# fuzzylab / app4.py
import os
import time
import re
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from streamlit_chat import message
from datasets import load_dataset  # use the Hugging Face datasets library for loading training data
from pathlib import Path
import torch
from PyPDF2 import PdfReader
import requests
from bs4 import BeautifulSoup
# Set page title and icon
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")
# Custom CSS for styling chat messages and buttons
st.markdown(
    """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 12px;
        padding: 10px 24px;
    }
    .stTextArea textarea {
        background-color: #f5f5f5;
    }
    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
    }
    .stMessageContainer {
        border-radius: 15px;
        padding: 10px;
        margin: 10px 0;
    }
    .stMessage--user {
        background-color: #dfe7f3;
        border-left: 6px solid #006699;
    }
    .stMessage--assistant {
        background-color: #f3f3f3;
        border-left: 6px solid #4CAF50;
    }
    pre {
        background-color: #f5f5f5;
        border-left: 6px solid #dfe7f3;
        padding: 10px;
        font-size: 14px;
        border-radius: 8px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
# Initialize session state variables
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "past" not in st.session_state:
st.session_state["past"] = []
if "messages" not in st.session_state:
st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
if "model_name" not in st.session_state:
st.session_state["model_name"] = []
if "total_tokens" not in st.session_state:
st.session_state["total_tokens"] = []
if "total_cost" not in st.session_state:
st.session_state["total_cost"] = 0.0
if "chat_data" not in st.session_state:
st.session_state["chat_data"] = [] # For storing the chat logs
if "uploaded_docs" not in st.session_state:
st.session_state["uploaded_docs"] = [] # For storing uploaded document content
if "web_data" not in st.session_state:
st.session_state["web_data"] = [] # For storing web scraped data
if "uploaded_file_path" not in st.session_state:
st.session_state["uploaded_file_path"] = "" # Store the path of saved files
# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
model_name = "gpt2"
# Parameters to adjust the response style and creativity
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=400, step=10)  # GPT-2's context window is 1024 tokens
# Load the model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
    model_path = "gpt2"  # Hugging Face Hub ID (or local path) of the model
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
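# Note: st.cache_resource keeps a single tokenizer/model pair in memory across Streamlit reruns;
# the model loads on CPU by default, which is workable for the 124M-parameter GPT-2 checkpoint.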
# GPT-2 has no dedicated pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token

# Sampling is always enabled for generation
do_sample = True

# Function to generate a response using the model with the configured generation settings
def generate_response(prompt):
    # Prepend any uploaded document text and scraped web text as context for the prompt
    context = " ".join(st.session_state["uploaded_docs"]) + " " + " ".join(st.session_state["web_data"]) + "\n" + prompt
    inputs = tokenizer(context, return_tensors="pt")
    generation_config = {
        "max_length": max_length,
        "temperature": temperature if do_sample else None,
        "top_p": top_p if do_sample else None,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
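# Note: GPT-2's context window is 1024 tokens. If the concatenated context (uploaded documents,
# scraped web pages, and the prompt) exceeds that, generation will fail or lose content, so large
# uploads may need to be trimmed or summarized before chatting.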
# Function to reset the session
def reset_session():
    st.session_state["generated"] = []
    st.session_state["past"] = []
    st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
    st.session_state["model_name"] = []
    st.session_state["total_tokens"] = []
    st.session_state["total_cost"] = 0.0
    st.session_state["chat_data"] = []  # Reset chat logs
    st.session_state["uploaded_docs"] = []  # Reset uploaded docs
    st.session_state["web_data"] = []  # Reset web data

# Reset chat button in sidebar
reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()
# Function to save chat logs for later fine-tuning
def save_chat_data(chat_data):
    with open("chat_data.json", "w") as f:
        json.dump(chat_data, f, indent=4)
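# Hypothetical helper (not wired into the UI): flatten the saved chat log into a plain-text file
# that fine_tune_model() below can consume. The file names and the "User:/Assistant:" formatting
# are assumptions, not part of the original app.
def chat_log_to_text_dataset(json_path="chat_data.json", txt_path="./dataset/chat_data.txt"):
    with open(json_path) as f:
        records = json.load(f)
    Path(txt_path).parent.mkdir(parents=True, exist_ok=True)
    with open(txt_path, "w") as f:
        for turn in records:
            f.write(f"User: {turn['user_input']}\nAssistant: {turn['model_response']}\n\n")
    return txt_path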
# Function to handle uploaded text or PDF files and convert PDF to txt
def handle_uploaded_file(uploaded_file):
    dataset_dir = "./dataset"
    Path(dataset_dir).mkdir(parents=True, exist_ok=True)  # make sure the dataset folder exists
    dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        # Read and extract text from the PDF
        pdf_reader = PdfReader(uploaded_file)
        text = ""
        for page in pdf_reader.pages:
            text += page.extract_text() or ""  # extract_text() can return None for empty pages
        # Save extracted text as a .txt file
        with open(dataset_path, "w") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        # If it's a text file, save it as is
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        text = uploaded_file.getvalue().decode("utf-8", errors="ignore")
        st.success(f"File saved to {dataset_path}")
    # Keep the text in session state so generate_response() can include it as chat context
    st.session_state["uploaded_docs"].append(text)
    st.session_state["uploaded_file_path"] = str(dataset_path)
# Add a file uploader for various formats
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
# Process uploaded file
if uploaded_file is not None:
    handle_uploaded_file(uploaded_file)
# Function to fetch and scrape website content
def handle_web_link(url):
    response = requests.get(url, timeout=10)  # avoid hanging indefinitely on unresponsive sites
    if response.status_code == 200:
        soup = BeautifulSoup(response.content, "html.parser")
        text = soup.get_text()
        st.session_state["web_data"].append(text)
        st.success(f"Content from {url} saved successfully!")
    else:
        st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
# Add a text box for entering website links
st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
# Process web link
if web_link:
    handle_web_link(web_link)
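# Note: Streamlit reruns this script on every interaction, so while a URL remains in the text box
# the page is fetched again and its text appended to web_data on each rerun.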
# Containers for chat history and user input
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        start_time = time.time()
        output = generate_response(user_input)
        end_time = time.time()
        inference_time = end_time - start_time
        # Append user input and model output to session state
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        st.session_state["model_name"].append(model_name)
        # Log chat data for future training
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )
        # Save chat data to a file (this could be used later for training)
        save_chat_data(st.session_state["chat_data"])
        # Calculate tokens and cost

# Display chat history
with response_container:
    for i in range(len(st.session_state["generated"])):
        message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
        message(st.session_state["generated"][i], key=str(i))
# Function to fine-tune the model using the uploaded dataset
def fine_tune_model():
    uploaded_file_path = st.session_state.get("uploaded_file_path", None)
    if not uploaded_file_path:
        st.warning("Please upload a dataset before training the model.")
        return
    # Load the text or CSV data
    if uploaded_file_path.endswith('.txt'):
        dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
    elif uploaded_file_path.endswith('.csv'):
        # The CSV is expected to contain a "text" column
        dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
    else:
        st.error("Unsupported dataset format. Please upload a TXT or CSV file.")
        return
    # Preprocess the data: convert the text into tokens
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    # Hold out a small evaluation split so the step-based evaluation below has data to run on
    split_dataset = tokenized_dataset.train_test_split(test_size=0.1)
    # Collator for causal language modeling (no masked language modeling)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # Check whether the system has a GPU available
    use_fp16 = torch.cuda.is_available()  # enable fp16 only when a GPU is available
    # Configure the TrainingArguments
    training_args = TrainingArguments(
        output_dir='./gpt2-finetuned',
        overwrite_output_dir=True,
        num_train_epochs=4,
        per_device_train_batch_size=3,
        per_device_eval_batch_size=3,
        save_steps=500,
        eval_strategy="steps",
        eval_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model='eval_loss',  # no compute_metrics is defined, so select the best model on eval loss
        greater_is_better=False,
        fp16=use_fp16,  # enable fp16 only when a GPU is available
        remove_unused_columns=False,  # disable this option to avoid column-mismatch issues
    )
    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
    )
    # Start training
    trainer.train()
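    # Persist the fine-tuned weights and tokenizer next to the Trainer checkpoints (an added
    # convenience step, not part of the original script) so the result can be reloaded later.
    trainer.save_model("./gpt2-finetuned")
    tokenizer.save_pretrained("./gpt2-finetuned")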
st.success("تم إكمال تدريب النموذج بنجاح.")
# Streamlit interface for uploading a dataset and starting fine-tuning
st.title("Fine-tune GPT-2 Model")
uploaded_file = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
if uploaded_file:
    st.session_state["uploaded_file_path"] = uploaded_file.name
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success(f"File {uploaded_file.name} uploaded successfully.")
if st.button("Start Fine-tuning"):
    fine_tune_model()
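# To run the app locally (assuming streamlit, streamlit-chat, transformers, datasets, torch,
# PyPDF2, requests, and beautifulsoup4 are installed):
#   streamlit run app4.py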