"""Streamlit app: chat with a local GPT-2, feed it uploaded docs / scraped web
pages as context, log the chat, and optionally fine-tune on an uploaded file."""

import os
import time
import re
import json
import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    TextDataset,
    DataCollatorForLanguageModeling,
)
from streamlit_chat import message
from pathlib import Path
import torch
from PyPDF2 import PdfReader
import requests
from bs4 import BeautifulSoup

# Set page title and icon
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# Custom CSS for styling chat messages and buttons
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# Initialize session state variables
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "past" not in st.session_state:
    st.session_state["past"] = []
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
if "chat_data" not in st.session_state:
    st.session_state["chat_data"] = []  # For storing the chat logs
if "uploaded_docs" not in st.session_state:
    st.session_state["uploaded_docs"] = []  # For storing uploaded document content
if "web_data" not in st.session_state:
    st.session_state["web_data"] = []  # For storing web scraped data

# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
model_name = "gpt2"

# Parameters to adjust the response style and creativity
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=800, step=10)


@st.cache_resource
def load_model_and_tokenizer():
    """Load the GPT-2 tokenizer and model once per session (cached by Streamlit)."""
    model_path = "gpt2"  # Path to the local model directory
    tokenizer = AutoTokenizer.from_pretrained("gpt2", clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model


tokenizer, model = load_model_and_tokenizer()


def generate_response(prompt):
    """
    Generate a response using the GPT-2 model, including document and web data context.

    The context is the concatenation of all uploaded-document text, all scraped
    web text, and the user prompt. It is truncated to the model's context
    window so very long contexts cannot crash generation.
    """
    context = (
        " ".join(st.session_state["uploaded_docs"])
        + " "
        + " ".join(st.session_state["web_data"])
        + "\n"
        + prompt
    )
    # Fix: without truncation, long doc/web context overflows GPT-2's
    # 1024-token window and generate() raises an index error.
    inputs = tokenizer(
        context,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    generation_config = {
        "max_length": max_length,
        # Fix: the slider allows 0.0, which makes sampling divide by zero;
        # clamp to a tiny positive value (near-greedy) instead of crashing.
        "temperature": max(temperature, 1e-3),
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": True,  # Always sample tokens
    }
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Reset session
def reset_session():
    """
    Reset all session state variables.
    """
    st.session_state["generated"] = []
    st.session_state["past"] = []
    st.session_state["messages"] = [{"role": "system", "content": "You are a helpful assistant."}]
    st.session_state["chat_data"] = []  # Reset chat logs
    st.session_state["uploaded_docs"] = []  # Reset uploaded docs
    st.session_state["web_data"] = []  # Reset web data


reset_button = st.sidebar.button("Reset Chat")
if reset_button:
    reset_session()


def save_chat_data(chat_data):
    """
    Save chat logs for future fine-tuning or reference.
    """
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4)


def handle_uploaded_file(uploaded_file):
    """
    Persist an uploaded txt/PDF file under ./datasets and register its text as
    chat context.

    Returns the saved file path as a string, or None if PDF text extraction fails.
    """
    dataset_dir = "./datasets"
    dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
    # Fix: the datasets directory may not exist yet; open() would raise
    # FileNotFoundError on the first upload.
    dataset_path.parent.mkdir(parents=True, exist_ok=True)

    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        pdf_reader = PdfReader(uploaded_file)
        # Fix: extract_text() returns None for image-only pages; treat as "".
        text = "".join(page.extract_text() or "" for page in pdf_reader.pages)
        if not text:
            st.error("Failed to extract text from the PDF.")
            return None  # Return None if text extraction fails
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        # Fix: generate_response reads uploaded_docs, but nothing ever
        # populated it — register the extracted text as chat context.
        st.session_state["uploaded_docs"].append(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        raw = uploaded_file.getbuffer()
        with open(dataset_path, "wb") as f:
            f.write(raw)
        # Fix: also register plain-text uploads as chat context (see above).
        st.session_state["uploaded_docs"].append(bytes(raw).decode("utf-8", errors="ignore"))
        st.success(f"File saved to {dataset_path}")
    return str(dataset_path)  # Return the path to the saved file


def handle_web_link(url):
    """
    Fetch and scrape text content from a website.
    """
    try:
        # Fix: no timeout lets a stalled host hang the whole Streamlit script.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "html.parser")
        text = soup.get_text()
        st.session_state["web_data"].append(text)
        st.success(f"Content from {url} saved successfully!")
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to retrieve content: {e}")


st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
if web_link:
    handle_web_link(web_link)

# Chat interface
response_container = st.container()
container = st.container()

with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")

    if submit_button and user_input:
        start_time = time.time()
        output = generate_response(user_input)
        inference_time = time.time() - start_time

        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        # Log chat data for future training
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )
        save_chat_data(st.session_state["chat_data"])

with response_container:
    for i in range(len(st.session_state["generated"])):
        message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
        message(st.session_state["generated"][i], key=str(i))


def fine_tune_model():
    """
    Fine-tune the cached GPT-2 model on the most recently uploaded text file.

    Reads the path stored in session state by the upload handler; shows a
    warning if nothing has been uploaded and surfaces any training error in
    the UI instead of crashing the app.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", "")
    if not uploaded_file_path:
        st.warning("Please upload a text or PDF dataset to fine-tune the model.")
        return

    # Prepare dataset for fine-tuning (using the uploaded .txt file)
    try:
        with open(uploaded_file_path, "r", encoding="utf-8") as f:
            text = f.read().strip()

        # Ensure that the file is not empty
        if len(text) == 0:
            raise ValueError("The dataset is empty.")

        # NOTE: TextDataset is deprecated in recent transformers releases;
        # kept here because the file already relies on it.
        train_dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=uploaded_file_path,  # Ensure this path is a .txt file
            block_size=128,
        )
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

        # Define training arguments
        training_args = TrainingArguments(
            output_dir="./gpt2-finetuned",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            save_steps=10_000,
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=200,
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )

        # Fine-tune the model
        trainer.train()
        st.success("Model fine-tuning completed successfully.")
    except Exception as e:
        st.error(f"Error during fine-tuning: {str(e)}")


# Sidebar file upload
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])

# Process uploaded file
if uploaded_file is not None:
    file_path = handle_uploaded_file(uploaded_file)
    if file_path:
        st.session_state["uploaded_file_path"] = file_path

# Add a button to trigger fine-tuning
st.sidebar.title("Fine-Tune Model")
fine_tune_button = st.sidebar.button("Fine-Tune GPT-2")
if fine_tune_button:
    fine_tune_model()