from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset
import gradio as gr
import torch
import os
import pandas as pd

# ------------------------------
# 1. Load dataset from CSV
# ------------------------------
CSV_FILE = "remedies.csv"  # path to your CSV
df = pd.read_csv(CSV_FILE)
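# The CSV is assumed to contain two text columns, 'symptoms' and 'response'
# (the names used below); an illustrative row might look like:
#   symptoms,response
#   "sore throat","Gargle with warm salt water and stay hydrated."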

# Prepare dataset in HuggingFace format
data = []
for _, row in df.iterrows():
    data.append({"input": row['symptoms'], "output": row['response']})
dataset = Dataset.from_list(data)
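# Each entry is a flat dict, e.g. {'input': '<symptoms>', 'output': '<response>'}.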

# ------------------------------
# 2. Model & Tokenizer
# ------------------------------
MODEL_NAME = "google/flan-t5-small"
OUTPUT_DIR = "./finetuned-bot"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
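# flan-t5-small is a compact (~80M parameter) encoder-decoder model, so this
# script can also be fine-tuned on CPU if no GPU is available (just more slowly).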

# ------------------------------
# 3. Preprocess
# ------------------------------
def preprocess(examples):
    # Tokenize inputs and targets; padding is left to the data collator,
    # which pads labels with -100 so padded positions are ignored in the loss.
    model_inputs = tokenizer(examples['input'], truncation=True, max_length=64)
    labels = tokenizer(text_target=examples['output'], truncation=True, max_length=64)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_dataset = dataset.map(preprocess, batched=True)
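# After mapping, each example carries 'input_ids', 'attention_mask', and 'labels'.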

# ------------------------------
# 4. Data collator
# ------------------------------
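# The collator pads each batch dynamically to its longest sequence and pads
# labels with -100, which the loss function ignores.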
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ------------------------------
# 5. Training
# ------------------------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    logging_steps=5,
    save_steps=50,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    report_to="none"
)
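# fp16 is enabled only when a CUDA GPU is present; on CPU the model trains
# in full float32 precision.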

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
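# Note: the plain Trainer suffices for computing the seq2seq loss here.
# Seq2SeqTrainer with Seq2SeqTrainingArguments (predict_with_generate=True)
# would be needed for generation-based evaluation during training.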

# Fine-tune only if directory is missing or empty
if not os.path.exists(OUTPUT_DIR) or not os.listdir(OUTPUT_DIR):
    trainer.train()
    model.save_pretrained(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
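# Delete (or empty) OUTPUT_DIR to force retraining on the next run.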

# ------------------------------
# 6. Load fine-tuned model
# ------------------------------
model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)

# ------------------------------
# 7. Gradio UI (blue & white)
# ------------------------------
def respond(user_input, chat_history):
    inputs = tokenizer(user_input, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=64)
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    chat_history = chat_history or []
    # gr.Chatbot renders a list of (user_message, bot_reply) pairs
    chat_history.append((user_input, reply))
    return chat_history, chat_history
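# Quick smoke test outside the UI (a sketch; the input string is illustrative):
# history, _ = respond("I have a mild headache", [])
# print(history[-1][1])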

with gr.Blocks(css="""
    body {background-color: #f0f8ff;}
    .gradio-container {border-radius: 15px; padding: 20px; background-color: #ffffff;}
    .chatbot-message.user {background-color: #cce5ff; color: #000;}
    .chatbot-message.bot {background-color: #e6f0ff; color: #000;}
""") as demo:
    gr.Markdown("<h1 style='text-align:center; color:#007BFF;'>💊 Health Remedies Chatbot</h1>")
    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", scale=8)
        send = gr.Button("Send", scale=2)
    send.click(respond, [msg, state], [chatbot, state])
    msg.submit(respond, [msg, state], [chatbot, state])

demo.launch()
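# To expose a temporary public link (e.g. from a notebook), launch with:
# demo.launch(share=True)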