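# NOTE(review): the original first lines "Spaces:", "Sleeping", "Sleeping" appear
# to be status text captured from the Hugging Face Spaces page this app was
# copied from; they are kept here as a comment so the module parses.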
import json
import os
import re
import tempfile
import time
from pathlib import Path

import requests
import streamlit as st
import torch
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from streamlit_chat import message
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
# Page title and icon. Streamlit expects set_page_config before other st.*
# elements, so it stays the first Streamlit command in the script.
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# Stylesheet injected into the page: buttons, the text area, chat message
# bubbles and <pre> code blocks.
_CUSTOM_CSS = """
<style>
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 12px;
padding: 10px 24px;
}
.stTextArea textarea {
background-color: #f5f5f5;
}
.stDownloadButton>button {
background-color: #4CAF50;
color: white;
}
.stMessageContainer {
border-radius: 15px;
padding: 10px;
margin: 10px 0;
}
.stMessage--user {
background-color: #dfe7f3;
border-left: 6px solid #006699;
}
.stMessage--assistant {
background-color: #f3f3f3;
border-left: 6px solid #4CAF50;
}
pre {
background-color: #f5f5f5;
border-left: 6px solid #dfe7f3;
padding: 10px;
font-size: 14px;
border-radius: 8px;
}
</style>
"""
# unsafe_allow_html is required for the raw <style> tag to be rendered.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def _init_session_state():
    """Create each per-session key the first time the script runs.

    Streamlit re-executes the script on every interaction; keys that already
    exist in ``st.session_state`` are left untouched so history survives
    reruns. Each default is produced by a factory so every key gets its own
    fresh list.
    """
    defaults = {
        "generated": list,
        "past": list,
        "messages": lambda: [{"role": "system", "content": "You are a helpful assistant."}],
        "chat_data": list,  # chat logs
        "uploaded_docs": list,  # uploaded document content
        "web_data": list,  # web scraped data
    }
    for key, make_default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()


_init_session_state()
# Sidebar: the (fixed) model choice and the sampling controls that
# generate_response reads as module-level globals.
st.sidebar.title("Model Selection")
model_name = "gpt2"

with st.sidebar:
    # Controls for the creativity / style of generated replies.
    st.title("Response Style Controls")
    temperature = st.slider(
        "Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1
    )
    top_p = st.slider(
        "Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05
    )
    top_k = st.slider(
        "Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1
    )
    repetition_penalty = st.slider(
        "Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1
    )
    max_length = st.slider(
        "Max Length", min_value=100, max_value=1024, value=800, step=10
    )
@st.cache_resource
def load_model_and_tokenizer(model_path="gpt2"):
    """Load the tokenizer and causal-LM model, once per server process.

    Streamlit re-executes this script on every widget interaction. Without
    ``st.cache_resource`` the weights would be reloaded from disk on every
    rerun. The cached pair is shared by all sessions, so in-place
    fine-tuning of ``model`` is seen by every user of this process.

    Args:
        model_path: Hub model id or local directory holding both the
            tokenizer and the model. Defaults to ``"gpt2"``.

    Returns:
        ``(tokenizer, model)``.
    """
    # Load both halves from the same source so they always match.
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model


tokenizer, model = load_model_and_tokenizer()
| def generate_response(prompt): | |
| """ | |
| Generate a response using the GPT-2 model, including document and web data context. | |
| """ | |
| context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt | |
| inputs = tokenizer(context, return_tensors="pt") | |
| generation_config = { | |
| "max_length": max_length, | |
| "temperature": temperature, | |
| "top_p": top_p, | |
| "top_k": top_k, | |
| "repetition_penalty": repetition_penalty, | |
| "pad_token_id": tokenizer.eos_token_id, | |
| "do_sample": True # Always sample tokens | |
| } | |
| outputs = model.generate( | |
| inputs.input_ids, | |
| attention_mask=inputs.attention_mask, | |
| **generation_config | |
| ) | |
| return tokenizer.decode(outputs[0], skip_special_tokens=True) | |
def reset_session():
    """Wipe the chat history, chat logs, uploaded text and scraped web text.

    Every list-valued key is replaced by a fresh empty list, and
    ``messages`` is re-seeded with the default system prompt.
    """
    for key in ("generated", "past", "chat_data", "uploaded_docs", "web_data"):
        st.session_state[key] = []
    st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]


# Sidebar button that clears the current session on click.
reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()
def save_chat_data(chat_data, path="chat_data.json"):
    """Persist the chat log as indented JSON for later fine-tuning or reference.

    The JSON is written to a temporary file in the destination directory
    and then moved over the target with ``os.replace``. A crash or a
    serialisation error mid-write therefore never leaves a truncated file
    behind; the previous contents stay intact.

    Args:
        chat_data: JSON-serialisable object (here a list of
            ``{"user_input": ..., "model_response": ...}`` dicts).
        path: Destination file. Defaults to ``"chat_data.json"`` in the
            current working directory.

    Raises:
        TypeError: if *chat_data* contains values ``json`` cannot serialise.
        OSError: if the file cannot be written.
    """
    target = Path(path)
    fd, tmp_name = tempfile.mkstemp(
        dir=str(target.parent), prefix=f".{target.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(chat_data, f, indent=4)
        os.replace(tmp_name, target)
    except BaseException:
        # Do not leave the partial temp file behind; the caller still sees the error.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
| def handle_uploaded_file(uploaded_file): | |
| dataset_dir = "./datasets" | |
| dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt" | |
| # Check if the file is a PDF | |
| if uploaded_file.type == "application/pdf": | |
| pdf_reader = PdfReader(uploaded_file) | |
| text = "" | |
| for page in pdf_reader.pages: | |
| text += page.extract_text() | |
| if not text: | |
| st.error("Failed to extract text from the PDF.") | |
| return None # Return None if text extraction fails | |
| with open(dataset_path, "w") as f: | |
| f.write(text) | |
| st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}") | |
| else: | |
| with open(dataset_path, "wb") as f: | |
| f.write(uploaded_file.getbuffer()) | |
| st.success(f"File saved to {dataset_path}") | |
| return str(dataset_path) # Return the path to the saved file | |
def handle_web_link(url, timeout=15):
    """Fetch *url*, strip the HTML, and add the page text to the chat context.

    The script reruns on every interaction while the URL stays in the text
    box. Page text is therefore memoised per session in
    ``st.session_state["web_cache"]`` (url -> text). It is appended to
    ``st.session_state["web_data"]`` only if that exact text is not already
    there, so the page is neither re-downloaded nor duplicated on each rerun.

    Args:
        url: Address typed by the user.
        timeout: Seconds to wait for the server. ``requests`` has no default
            timeout, so without one a stalled server would hang the app.
    """
    cache = st.session_state.setdefault("web_cache", {})
    text = cache.get(url)
    if text is None:
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            st.error(f"Failed to retrieve content: {e}")
            return
        soup = BeautifulSoup(response.content, "html.parser")
        # <script>/<style> contents are code, not readable page text.
        for tag in soup(["script", "style"]):
            tag.decompose()
        text = soup.get_text(separator=" ", strip=True)
        cache[url] = text

    if text in st.session_state["web_data"]:
        return  # already part of the context from an earlier rerun
    st.session_state["web_data"].append(text)
    st.success(f"Content from {url} saved successfully!")


st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
if web_link:
    handle_web_link(web_link.strip())
# Chat interface: history renders in response_container (created first, so it
# appears above) and the input form renders in container below it.
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        start_time = time.perf_counter()
        try:
            output = generate_response(user_input)
        except Exception as e:
            # UI boundary: report generation failures instead of replacing
            # the page with a traceback.
            st.error(f"Failed to generate a response: {e}")
        else:
            inference_time = time.perf_counter() - start_time
            st.session_state["past"].append(user_input)
            st.session_state["generated"].append(output)
            # Log chat data for future training
            st.session_state["chat_data"].append(
                {"user_input": user_input, "model_response": output}
            )
            save_chat_data(st.session_state["chat_data"])
            st.caption(f"Response generated in {inference_time:.2f} s")

with response_container:
    # Pair each user turn with the reply generated for it.
    turns = zip(st.session_state["past"], st.session_state["generated"])
    for i, (user_text, reply_text) in enumerate(turns):
        message(user_text, is_user=True, key=f"{i}_user")
        message(reply_text, key=str(i))
def fine_tune_model():
    """Fine-tune the module-level GPT-2 ``model`` on the last uploaded dataset.

    The dataset path comes from ``st.session_state["uploaded_file_path"]``
    and must point to a plain-text file. Training uses a causal-LM objective
    (``mlm=False``) over 128-token blocks for 3 epochs and updates ``model``
    in place. The weights and tokenizer are then saved to
    ``./gpt2-finetuned``. Results and errors are reported through Streamlit
    messages; nothing is returned and no exception propagates.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", "")
    if not uploaded_file_path:
        st.warning("Please upload a text or PDF dataset to fine-tune the model.")
        return

    output_dir = "./gpt2-finetuned"
    try:
        # TextDataset decodes the file as UTF-8, so validate it the same way.
        with open(uploaded_file_path, "r", encoding="utf-8") as f:
            text = f.read().strip()
        if not text:
            raise ValueError("The dataset is empty.")

        train_dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=uploaded_file_path,  # must be a plain-text file
            block_size=128,
            # Tokenised features are cached beside the file under a name built
            # from the file name only; rebuild so a re-upload with new content
            # under the same name is not trained on stale data.
            overwrite_cache=True,
        )
        if len(train_dataset) == 0:
            # Only complete 128-token blocks become examples; a shorter text
            # yields none, and Trainer would fail later with a less clear error.
            raise ValueError("The dataset is too short: at least 128 tokens are needed.")
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            save_steps=10_000,
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=200,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )
        trainer.train()

        # Persist the result; otherwise it exists only in this process's memory.
        trainer.save_model(output_dir)
        tokenizer.save_pretrained(output_dir)
        st.success("Model fine-tuning completed successfully.")
    except Exception as e:
        st.error(f"Error during fine-tuning: {e}")
    finally:
        # Training puts the model in train mode (dropout active). Restore
        # inference mode for later generate() calls; a no-op if already in eval.
        model.eval()
# Sidebar file upload
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])

# Process the uploaded file. The uploader returns the same file on every
# rerun, so only handle a file that differs from the last one processed;
# otherwise it was re-parsed and re-written (and re-announced) on every click.
if uploaded_file is not None:
    upload_signature = (
        uploaded_file.name,
        uploaded_file.size,
        getattr(uploaded_file, "file_id", None),
    )
    if st.session_state.get("processed_upload") != upload_signature:
        file_path = handle_uploaded_file(uploaded_file)
        if file_path:
            st.session_state["uploaded_file_path"] = file_path
            st.session_state["processed_upload"] = upload_signature
else:
    # The file was removed from the uploader: do not keep fine-tuning on it.
    st.session_state.pop("uploaded_file_path", None)
    st.session_state.pop("processed_upload", None)

# Add a button to trigger fine-tuning
st.sidebar.title("Fine-Tune Model")
fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
if fine_tune_button:
    # Training blocks this rerun for a long time; show that work is ongoing.
    with st.spinner("Fine-tuning GPT-2..."):
        fine_tune_model()