# NOTE: removed Hugging Face Spaces status header ("Spaces: / Sleeping / Sleeping")
# that leaked into this file during page extraction — it is not Python code.
# Standard library
import os
import time
import re
import json
from pathlib import Path

# Third-party
import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from datasets import load_dataset  # used to load TXT/CSV training data
from PyPDF2 import PdfReader
from streamlit_chat import message
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,  # fix: was imported twice in the original import line
    DataCollatorForLanguageModeling,
    TextDataset,  # NOTE(review): unused here and deprecated upstream — candidate for removal
    Trainer,
    TrainingArguments,
)
# Browser-tab identity for the app.
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# CSS injected once to style buttons, text areas, download buttons and the
# chat bubbles rendered by streamlit_chat.
_CUSTOM_CSS = """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 12px;
        padding: 10px 24px;
    }
    .stTextArea textarea {
        background-color: #f5f5f5;
    }
    .stDownloadButton>button {
        background-color: #4CAF50;
        color: white;
    }
    .stMessageContainer {
        border-radius: 15px;
        padding: 10px;
        margin: 10px 0;
    }
    .stMessage--user {
        background-color: #dfe7f3;
        border-left: 6px solid #006699;
    }
    .stMessage--assistant {
        background-color: #f3f3f3;
        border-left: 6px solid #4CAF50;
    }
    pre {
        background-color: #f5f5f5;
        border-left: 6px solid #dfe7f3;
        padding: 10px;
        font-size: 14px;
        border-radius: 8px;
    }
    </style>
    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# One-time initialisation of every session-state slot the app reads/writes.
# The dict is rebuilt on each rerun, so no default object is shared between
# sessions; a slot is only (re)assigned when it is missing.
_SESSION_DEFAULTS = {
    "generated": [],            # model replies, in turn order
    "past": [],                 # user messages, in turn order
    "messages": [{"role": "system", "content": "You are a helpful assistant."}],
    "model_name": [],           # model used for each turn
    "total_tokens": [],
    "total_cost": 0.0,
    "chat_data": [],            # chat log persisted for later fine-tuning
    "uploaded_docs": [],        # raw text of uploaded documents
    "web_data": [],             # text scraped from user-supplied URLs
    "uploaded_file_path": "",   # path of the most recently saved upload
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Sidebar - model selection (fixed to GPT-2) and decoding controls.
st.sidebar.title("Model Selection")
model_name = "gpt2"

# Decoding/style parameters exposed to the user; read by generate_response().
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
# Fix: capped at 1024 — GPT-2's context window is 1024 positions, so the
# original 4048 ceiling let users pick lengths model.generate() cannot honor.
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=400, step=10)
# Load the model and tokenizer.
# Fix: cached with st.cache_resource — Streamlit reruns the whole script on
# every interaction, and the original reloaded the full GPT-2 checkpoint each
# time, making every click pay the model-load cost.
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 checkpoint and its tokenizer once per process.

    Returns:
        (tokenizer, model) tuple.
    """
    model_path = "gpt2"  # local path / hub id of the model
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
# GPT-2 ships without a pad token; reuse EOS so padding is well-defined.
tokenizer.pad_token = tokenizer.eos_token

# Sampling is enabled so the sidebar creativity controls take effect.
# Fix: defined BEFORE generate_response — the original assigned it only after
# the function body, relying on call-time lookup of the global.
do_sample = True

def generate_response(prompt):
    """Generate a GPT-2 reply for *prompt*.

    The prompt is prefixed with any uploaded-document text and scraped web
    text held in session state. Returns the full decoded sequence
    (context + prompt + continuation), matching the original behavior.
    """
    context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
    # Fix: truncate to the tokenizer's declared maximum — an oversized
    # context made the original fail inside model.generate().
    inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length)
    generation_config = {
        "max_length": max_length,
        "temperature": temperature if do_sample else None,
        "top_p": top_p if do_sample else None,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            **generation_config,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Function to reset the session.
def reset_session():
    """Restore every session-state slot to its initial value.

    Fix: the original left "uploaded_file_path" stale after a reset, so a
    later fine-tuning run could still point at a stale/removed upload even
    though the uploaded docs themselves were cleared.
    """
    st.session_state["generated"] = []
    st.session_state["past"] = []
    st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
    st.session_state["model_name"] = []
    st.session_state["total_tokens"] = []
    st.session_state["total_cost"] = 0.0
    st.session_state["chat_data"] = []       # reset chat logs
    st.session_state["uploaded_docs"] = []   # reset uploaded docs
    st.session_state["web_data"] = []        # reset web data
    st.session_state["uploaded_file_path"] = ""  # was missing from the original reset
# Sidebar control that wipes the whole conversation state.
if st.sidebar.button("Reset Chat"):
    reset_session()
# Persist the chat log so it can later be used as fine-tuning data.
def save_chat_data(chat_data):
    """Write *chat_data* to chat_data.json as pretty-printed UTF-8 JSON.

    Fixes over the original:
    - explicit UTF-8 encoding (the platform-default codec could fail on the
      Arabic text this app handles);
    - ensure_ascii=False so non-ASCII turns are stored readably instead of
      as \\uXXXX escapes.
    """
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)
# Handle an uploaded text or PDF file; PDFs are converted to plain text.
def handle_uploaded_file(uploaded_file):
    """Save the uploaded file under ./dataset and remember its path.

    Fixes over the original:
    - creates ./dataset if it does not exist (open() used to fail on a
      fresh checkout);
    - writes extracted PDF text as UTF-8 (platform-default encoding could
      crash on Arabic content);
    - tolerates PDF pages where extract_text() returns None (image-only pages).
    """
    dataset_dir = Path("./dataset")
    dataset_dir.mkdir(parents=True, exist_ok=True)
    dataset_path = dataset_dir / f"{uploaded_file.name}.txt"

    if uploaded_file.type == "application/pdf":
        # Read and extract text from the PDF, then save it as a .txt file.
        pdf_reader = PdfReader(uploaded_file)
        text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        # Text files are saved verbatim.
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")

    # NOTE(review): assumed to apply to both branches, as in the original.
    st.session_state["uploaded_file_path"] = str(dataset_path)
# Sidebar uploader for chat-context documents (plain text or PDF).
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
if uploaded_file is not None:
    handle_uploaded_file(uploaded_file)
# Fetch a web page and add its visible text to the chat context.
def handle_web_link(url):
    """Download *url*, strip the HTML, and stash the text in session state.

    Fixes over the original:
    - a 10-second request timeout (requests.get has none by default, so a
      dead host could hang the whole app);
    - connection errors are reported via st.error instead of crashing the
      script with an unhandled exception.
    """
    try:
        response = requests.get(url, timeout=10)
    except requests.RequestException as exc:
        st.error(f"Failed to retrieve content from {url}: {exc}")
        return

    if response.status_code == 200:
        soup = BeautifulSoup(response.content, "html.parser")
        st.session_state["web_data"].append(soup.get_text())
        st.success(f"Content from {url} saved successfully!")
    else:
        st.error(f"Failed to retrieve content from {url}. Status code: {response.status_code}")
# Sidebar input for scraping a web page into the chat context.
st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
# Fix: fetch only on an explicit click. The original called handle_web_link
# on EVERY Streamlit rerun while text sat in the box, re-downloading the page
# and appending duplicate copies to session_state["web_data"] per interaction.
if st.sidebar.button("Fetch Website Content") and web_link:
    handle_web_link(web_link)
# Layout: the history container is declared first so replies render above
# the input form.
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        # Time the generation (measured but not currently displayed).
        start_time = time.time()
        output = generate_response(user_input)
        end_time = time.time()
        inference_time = end_time - start_time

        # Record this turn in session state.
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        st.session_state["model_name"].append(model_name)

        # Log and persist the exchange for future fine-tuning.
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )
        save_chat_data(st.session_state["chat_data"])
        # Token/cost accounting intentionally not implemented yet.

# Render the full chat history as alternating user/assistant bubbles.
with response_container:
    for i in range(len(st.session_state["generated"])):
        message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
        message(st.session_state["generated"][i], key=str(i))
# Fine-tune the model on the uploaded dataset (TXT, or CSV with a 'text' column).
def fine_tune_model():
    """Run a causal-LM fine-tuning pass of GPT-2 on the uploaded file.

    Fixes over the original:
    - unsupported extensions are rejected up front instead of crashing later
      on an undefined `dataset` variable;
    - removed eval_strategy="steps" / load_best_model_at_end /
      metric_for_best_model='accuracy': no eval_dataset or compute_metrics
      was ever supplied, so the Trainer aborted before training.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", None)
    if not uploaded_file_path:
        st.warning("يرجى تحميل dataset لتدريب النموذج.")
        return

    # Load text or CSV data; anything else is refused.
    if uploaded_file_path.endswith('.txt'):
        dataset = load_dataset('text', data_files=uploaded_file_path, split='train')
    elif uploaded_file_path.endswith('.csv'):
        dataset = load_dataset('csv', data_files=uploaded_file_path, split='train')
    else:
        st.error("Unsupported dataset format. Please upload a TXT or CSV file.")
        return

    # Tokenize every example to fixed-length 512-token blocks.
    # NOTE(review): assumes the dataset exposes a 'text' column (true for the
    # 'text' loader; a CSV must provide one) — confirm for CSV inputs.
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Causal LM objective: no masked-language-modeling.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # fp16 only makes sense when a CUDA device is available.
    use_fp16 = torch.cuda.is_available()

    training_args = TrainingArguments(
        output_dir='./gpt2-finetuned',
        overwrite_output_dir=True,
        num_train_epochs=4,
        per_device_train_batch_size=3,
        per_device_eval_batch_size=3,
        save_steps=500,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=100,
        save_total_limit=3,
        fp16=use_fp16,
        remove_unused_columns=False,  # inputs were already pruned in map() above
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=tokenized_dataset,
    )

    trainer.train()
    st.success("تم إكمال تدريب النموذج بنجاح.")
# Streamlit UI for uploading a training dataset and launching fine-tuning.
st.title("Fine-tune GPT-2 Model")

# NOTE(review): this rebinds the module-level `uploaded_file` used by the
# chat-context uploader above, exactly as the original script did.
uploaded_file = st.file_uploader("Upload your dataset (TXT or CSV)", type=['txt', 'csv'])
if uploaded_file:
    # Persist the upload to the working directory and remember its path.
    st.session_state["uploaded_file_path"] = uploaded_file.name
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.success(f"File {uploaded_file.name} uploaded successfully.")

if st.button("Start Fine-tuning"):
    fine_tune_model()