# fuzzylab/app.py — Streamlit GPT-2 document uploader, chat, and fine-tuning app.
import os
import time
import re
import json
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, TextDataset, DataCollatorForLanguageModeling
from streamlit_chat import message
from pathlib import Path
import torch
from PyPDF2 import PdfReader
import requests
from bs4 import BeautifulSoup
# Page title and icon shown in the browser tab.
st.set_page_config(page_title="GPT-2 Text Uploader and Trainer", page_icon=":robot_face:")

# Custom CSS for chat bubbles, buttons, and code blocks, injected once at startup.
_CHAT_CSS = """
<style>
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 12px;
padding: 10px 24px;
}
.stTextArea textarea {
background-color: #f5f5f5;
}
.stDownloadButton>button {
background-color: #4CAF50;
color: white;
}
.stMessageContainer {
border-radius: 15px;
padding: 10px;
margin: 10px 0;
}
.stMessage--user {
background-color: #dfe7f3;
border-left: 6px solid #006699;
}
.stMessage--assistant {
background-color: #f3f3f3;
border-left: 6px solid #4CAF50;
}
pre {
background-color: #f5f5f5;
border-left: 6px solid #dfe7f3;
padding: 10px;
font-size: 14px;
border-radius: 8px;
}
</style>
"""
st.markdown(_CHAT_CSS, unsafe_allow_html=True)
# Initialize session state variables
# Seed every session-state key the app relies on, once per new session.
_SESSION_DEFAULTS = {
    "generated": [],   # model responses shown in the transcript
    "past": [],        # user inputs shown in the transcript
    "messages": [{"role": "system", "content": "You are a helpful assistant."}],
    "chat_data": [],   # chat logs persisted to chat_data.json
    "uploaded_docs": [],  # text pulled from uploaded documents
    "web_data": [],       # text scraped from user-supplied URLs
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Sidebar - Model Selection, Style Parameters, and Cost Display
st.sidebar.title("Model Selection")
# Fixed to the base GPT-2 checkpoint; no selector widget is rendered.
model_name = "gpt2"
# Parameters to adjust the response style and creativity.
# These module-level values are read directly by generate_response() on each rerun.
st.sidebar.title("Response Style Controls")
temperature = st.sidebar.slider("Creativity (Temperature)", min_value=0.0, max_value=1.5, value=0.7, step=0.1)
top_p = st.sidebar.slider("Nucleus Sampling (Top-p)", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
top_k = st.sidebar.slider("Token Sampling (Top-k)", min_value=1, max_value=100, value=50, step=1)
repetition_penalty = st.sidebar.slider("Repetition Penalty", min_value=1.0, max_value=2.0, value=1.2, step=0.1)
max_length = st.sidebar.slider("Max Length", min_value=100, max_value=1024, value=800, step=10)
@st.cache_resource
def load_model_and_tokenizer():
    """Load and cache the GPT-2 tokenizer and model.

    Wrapped in st.cache_resource so the expensive load happens once per
    server process instead of on every Streamlit rerun.

    Returns:
        tuple: (tokenizer, model) for the GPT-2 checkpoint.
    """
    model_path = "gpt2"  # Hub id (or local directory) of the checkpoint
    # Load both from the same identifier so tokenizer and model cannot
    # drift apart (previously the tokenizer hard-coded "gpt2" separately).
    tokenizer = AutoTokenizer.from_pretrained(model_path, clean_up_tokenization_spaces=True)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return tokenizer, model

tokenizer, model = load_model_and_tokenizer()
def generate_response(prompt):
    """
    Generate a response with GPT-2, prefixing the prompt with any
    uploaded-document and scraped-web context held in session state.

    Args:
        prompt (str): The user's chat message.

    Returns:
        str: Only the newly generated text (the echoed context/prompt is
        stripped from the decoded output).
    """
    context = " ".join(st.session_state['uploaded_docs']) + " " + " ".join(st.session_state['web_data']) + "\n" + prompt
    # Truncate to GPT-2's context window; without this, a long document/web
    # context makes model.generate() fail outright on oversized inputs.
    inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=1024)
    generation_config = {
        "max_length": max_length,
        # temperature == 0.0 (slider minimum) is invalid with do_sample=True,
        # so clamp to a tiny positive value that behaves like greedy sampling.
        "temperature": max(temperature, 1e-5),
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": True  # Always sample tokens
    }
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        **generation_config
    )
    # Decode only the tokens produced after the input so the reply does not
    # echo the entire document/web context back into the chat transcript.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# Reset session
def reset_session():
    """Restore every chat-related session-state key to its initial value."""
    fresh_state = {
        "generated": [],
        "past": [],
        "messages": [{"role": "system", "content": "You are a helpful assistant."}],
        "chat_data": [],       # clears the chat log
        "uploaded_docs": [],   # clears uploaded document text
        "web_data": [],        # clears scraped web text
    }
    for key, value in fresh_state.items():
        st.session_state[key] = value
# Clear the conversation when the sidebar reset button is pressed.
if st.sidebar.button("Reset Chat"):
    reset_session()
def save_chat_data(chat_data):
    """Persist chat logs to chat_data.json for future fine-tuning or reference.

    Args:
        chat_data (list[dict]): Entries of the form
            {"user_input": ..., "model_response": ...}.
    """
    # Explicit UTF-8 (the platform default may be e.g. cp1252 on Windows) and
    # ensure_ascii=False keep non-ASCII chat text intact and human-readable.
    with open("chat_data.json", "w", encoding="utf-8") as f:
        json.dump(chat_data, f, indent=4, ensure_ascii=False)
def handle_uploaded_file(uploaded_file):
    """Save an uploaded txt/pdf file under ./datasets and return its path.

    PDF uploads have their text extracted and written as UTF-8 text;
    all other files are written as raw bytes.

    Args:
        uploaded_file: A Streamlit UploadedFile from st.file_uploader.

    Returns:
        str | None: Path of the saved file, or None when PDF text
        extraction yields nothing.
    """
    dataset_dir = "./datasets"
    # The directory is not guaranteed to exist on a fresh deployment;
    # open() below would raise FileNotFoundError without this.
    Path(dataset_dir).mkdir(parents=True, exist_ok=True)
    dataset_path = Path(dataset_dir) / f"{uploaded_file.name}.txt"
    # Check if the file is a PDF
    if uploaded_file.type == "application/pdf":
        pdf_reader = PdfReader(uploaded_file)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages; coerce to "".
            text += page.extract_text() or ""
        if not text:
            st.error("Failed to extract text from the PDF.")
            return None  # Return None if text extraction fails
        with open(dataset_path, "w", encoding="utf-8") as f:
            f.write(text)
        st.success(f"{uploaded_file.name} uploaded successfully as {dataset_path}")
    else:
        with open(dataset_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved to {dataset_path}")
    return str(dataset_path)  # Return the path to the saved file
def handle_web_link(url):
    """Fetch a web page, extract its visible text, and store it in session state.

    Args:
        url (str): Address to scrape.

    Side effects:
        Appends the page text to st.session_state["web_data"] on success;
        shows a Streamlit error message on failure.
    """
    try:
        # A timeout keeps an unresponsive host from hanging the Streamlit
        # script run forever (requests has no default timeout).
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "html.parser")
        text = soup.get_text()
        st.session_state["web_data"].append(text)
        st.success(f"Content from {url} saved successfully!")
    except requests.exceptions.RequestException as e:
        st.error(f"Failed to retrieve content: {e}")
st.sidebar.title("Add Website Links")
web_link = st.sidebar.text_input("Enter Website URL")
# Streamlit reruns this script on every interaction; without a guard the
# same URL is re-fetched and re-appended to web_data on every rerun.
if "fetched_urls" not in st.session_state:
    st.session_state["fetched_urls"] = set()
if web_link and web_link not in st.session_state["fetched_urls"]:
    handle_web_link(web_link)
    # NOTE(review): failed fetches are also marked as seen, which blocks
    # automatic retries — acceptable here since reruns are frequent.
    st.session_state["fetched_urls"].add(web_link)
# Chat interface: one container for the transcript, one for the input form.
response_container = st.container()
container = st.container()
with container:
    with st.form(key="user_input_form"):
        user_input = st.text_area("You:", key="user_input", height=100)
        submit_button = st.form_submit_button("Send")
    if submit_button and user_input:
        start_time = time.time()
        output = generate_response(user_input)
        # NOTE(review): inference_time is measured but never displayed.
        inference_time = time.time() - start_time
        st.session_state["past"].append(user_input)
        st.session_state["generated"].append(output)
        # Log chat data for future training
        st.session_state["chat_data"].append(
            {"user_input": user_input, "model_response": output}
        )
        # Persist the full log to chat_data.json after every exchange.
        save_chat_data(st.session_state["chat_data"])
with response_container:
    # Replay the whole conversation; each message widget needs a unique key.
    for i in range(len(st.session_state["generated"])):
        message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
        message(st.session_state["generated"][i], key=str(i))
def fine_tune_model():
    """Fine-tune the cached GPT-2 model on the most recently uploaded dataset.

    Reads the file path stored by the upload handler from session state,
    builds a causal-LM dataset from it, and runs a Trainer loop, writing
    checkpoints under ./gpt2-finetuned. Errors are reported via st.error.
    """
    uploaded_file_path = st.session_state.get("uploaded_file_path", "")
    if not uploaded_file_path:
        st.warning("Please upload a text or PDF dataset to fine-tune the model.")
        return
    # Prepare dataset for fine-tuning (using the uploaded .txt file)
    try:
        # Explicit UTF-8 matches how the upload handler writes the file;
        # errors="replace" keeps a stray byte from aborting the whole run.
        with open(uploaded_file_path, "r", encoding="utf-8", errors="replace") as f:
            text = f.read().strip()  # Ensure that the file is not empty
        if not text:
            raise ValueError("The dataset is empty.")
        # NOTE(review): TextDataset is deprecated in recent transformers
        # releases; consider migrating to the `datasets` library.
        train_dataset = TextDataset(
            tokenizer=tokenizer,
            file_path=uploaded_file_path,  # Ensure this path is a .txt file
            block_size=128,
        )
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
        # Define training arguments
        training_args = TrainingArguments(
            output_dir="./gpt2-finetuned",
            overwrite_output_dir=True,
            num_train_epochs=3,
            per_device_train_batch_size=8,
            save_steps=10_000,
            save_total_limit=2,
            logging_dir="./logs",
            logging_steps=200,
        )
        # Initialize the Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            data_collator=data_collator,
            train_dataset=train_dataset,
        )
        # Fine-tune the model
        trainer.train()
        st.success("Model fine-tuning completed successfully.")
    except Exception as e:
        st.error(f"Error during fine-tuning: {str(e)}")
# Sidebar: document upload widget.
st.sidebar.title("Upload Documents")
uploaded_file = st.sidebar.file_uploader("Choose a file", type=["txt", "pdf"])
# Save any newly uploaded file and remember its path for fine-tuning.
if uploaded_file is not None:
    saved_path = handle_uploaded_file(uploaded_file)
    if saved_path:
        st.session_state["uploaded_file_path"] = saved_path
# Sidebar: button that kicks off a fine-tuning run on the saved dataset.
st.sidebar.title("Fine-Tune Model")
if st.sidebar.button("Fine-Tune GPT-2"):
    fine_tune_model()