# wellBeingBot / app.py
# (Hugging Face Space: shyatri — commit d9356a9 "Update app.py")
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset, load_dataset
import gradio as gr
import torch
import os
import pandas as pd
# ------------------------------
# 1. Load dataset from CSV
# ------------------------------
CSV_FILE = "remedies.csv"  # path to your CSV
df = pd.read_csv(CSV_FILE)
# Turn each (symptoms, response) row into the {"input", "output"} record
# shape the preprocessing step expects, then wrap in a HuggingFace Dataset.
records = [
    {"input": symptoms, "output": response}
    for symptoms, response in zip(df["symptoms"], df["response"])
]
dataset = Dataset.from_list(records)
# ------------------------------
# 2. Model & Tokenizer
# ------------------------------
# Base checkpoint to fine-tune; a small seq2seq model so training stays cheap.
MODEL_NAME = "google/flan-t5-small"
# Where the fine-tuned weights/tokenizer are saved and later reloaded from.
OUTPUT_DIR = "./finetuned-bot"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# ------------------------------
# 3. Preprocess
# ------------------------------
def preprocess(examples):
    """Tokenize a batch of symptom/response pairs for seq2seq training.

    Args:
        examples: batched dataset columns with 'input' (symptoms text) and
            'output' (remedy text) lists.

    Returns:
        Tokenizer encodings with 'labels' added; label positions that are
        padding are set to -100 so they are ignored by the loss.
    """
    model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
    # `text_target=` is the current API for target tokenization; the
    # `as_target_tokenizer()` context manager is deprecated.
    labels = tokenizer(text_target=examples['output'], truncation=True, padding='max_length', max_length=64)
    # Mask pad token ids with -100: otherwise every padded position of the
    # 64-token label would contribute to the cross-entropy loss.
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels['input_ids']
    ]
    return model_inputs
# Batched map applies preprocess over the whole dataset; the original text
# columns remain alongside the tokenized ones.
tokenized_dataset = dataset.map(preprocess, batched=True)
# ------------------------------
# 4. Data collator
# ------------------------------
# Pads input_ids / attention_mask / labels consistently per training batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# ------------------------------
# 5. Training
# ------------------------------
# Lightweight fine-tuning configuration: small batch, 3 epochs, at most two
# retained checkpoints to bound disk usage; fp16 only when CUDA is available.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    logging_steps=5,
    save_steps=50,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    report_to="none"  # disable wandb/tensorboard reporting
)
# Plain Trainer is sufficient here: the seq2seq collator handles label
# padding and no generation-based evaluation is configured.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
# Fine-tune only when no final model has been exported yet. We check for
# config.json (written by save_pretrained) rather than "directory non-empty":
# a directory holding only mid-training checkpoint subfolders or stray files
# would otherwise skip training and then fail to load below.
if not os.path.isfile(os.path.join(OUTPUT_DIR, "config.json")):
    trainer.train()
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
# ------------------------------
# 6. Load fine-tuned model
# ------------------------------
# Reload from disk so inference uses the fine-tuned weights regardless of
# whether training ran in this process or in a previous run.
model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
# ------------------------------
# 7. Gradio UI (blue & white)
# ------------------------------
def respond(user_input, chat_history):
    """Generate a remedy reply and update the chat transcript.

    Args:
        user_input: the visitor's message.
        chat_history: list of (user_message, bot_message) tuples, or None on
            the first turn.

    Returns:
        (chat_history, chat_history) — once for the Chatbot display and once
        for the session State.
    """
    inputs = tokenizer(user_input, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=64)
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    chat_history = chat_history or []
    # gr.Chatbot renders each tuple as (user message, bot message). The old
    # ("You", user_input) / ("Bot", reply) entries displayed the literal
    # labels "You"/"Bot" as user messages and the real text on the bot side.
    chat_history.append((user_input, reply))
    return chat_history, chat_history
with gr.Blocks(css="""
body {background-color: #f0f8ff;}
.gradio-container {border-radius: 15px; padding: 20px; background-color: #ffffff;}
.chatbot-message.user {background-color: #cce5ff; color: #000;}
.chatbot-message.bot {background-color: #e6f0ff; color: #000;}
""") as demo:
    # Heading emoji repaired: "๐Ÿ’Š" was the UTF-8 pill emoji 💊 mis-decoded
    # through a legacy Thai codepage.
    gr.Markdown("<h1 style='text-align:center; color:#007BFF;'>💊 Health Remedies Chatbot</h1>")
    chatbot = gr.Chatbot()
    state = gr.State([])  # per-session transcript backing the Chatbot widget
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", scale=8)
        send = gr.Button("Send", scale=2)
    # Both the Send button and pressing Enter in the textbox call respond().
    send.click(respond, [msg, state], [chatbot, state])
    msg.submit(respond, [msg, state], [chatbot, state])
demo.launch()