shyatri committed on
Commit
f48a29c
·
verified ·
1 Parent(s): 380074a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -27
app.py CHANGED
@@ -1,15 +1,26 @@
1
- # Full script: Fine-tune + Gradio UI
2
 
3
- from datasets import load_dataset
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
 
5
  import gradio as gr
6
  import torch
7
  import os
8
 
9
  # ------------------------------
10
- # 1. Dataset
11
  # ------------------------------
12
- dataset = load_dataset('csv', data_files={'train': 'remedies.csv'}, split='train')
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # ------------------------------
15
  # 2. Model & Tokenizer
@@ -24,10 +35,8 @@ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
24
  # 3. Preprocess
25
  # ------------------------------
26
  def preprocess(examples):
27
- inputs = examples['input']
28
- targets = examples['output']
29
- model_inputs = tokenizer(inputs, truncation=True, padding='max_length', max_length=64)
30
- labels = tokenizer(targets, truncation=True, padding='max_length', max_length=64)
31
  model_inputs['labels'] = labels['input_ids']
32
  return model_inputs
33
 
@@ -45,47 +54,52 @@ training_args = TrainingArguments(
45
  output_dir=OUTPUT_DIR,
46
  per_device_train_batch_size=4,
47
  num_train_epochs=3,
48
- logging_steps=10,
49
  save_steps=50,
50
  save_total_limit=2,
51
  fp16=torch.cuda.is_available(),
52
- report_to="none",
53
  )
54
 
55
  trainer = Trainer(
56
  model=model,
57
  args=training_args,
58
  train_dataset=tokenized_dataset,
59
- tokenizer=tokenizer,
60
  data_collator=data_collator
61
  )
62
 
63
- # Fine-tune only if model is not already saved
64
  if not os.path.exists(OUTPUT_DIR):
65
  trainer.train()
66
- trainer.save_model(OUTPUT_DIR)
 
67
 
68
  # ------------------------------
69
- # 6. Load fine-tuned model for inference
70
  # ------------------------------
71
  model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
72
  tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
73
 
74
- def respond(user_input):
 
 
 
75
  inputs = tokenizer(user_input, return_tensors="pt")
76
  outputs = model.generate(**inputs, max_length=64)
77
  reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
78
- return reply
 
 
 
79
 
80
- # ------------------------------
81
- # 7. Gradio UI
82
- # ------------------------------
83
- ui = gr.Interface(
84
- fn=respond,
85
- inputs="text",
86
- outputs="text",
87
- title="Health Remedies Chatbot",
88
- description="Ask health-related questions and get remedy suggestions!"
89
- )
90
 
91
- ui.launch()
 
1
+ # app.py: Fine-tune internally + Gradio UI
2
 
 
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
4
+ from datasets import Dataset
5
  import gradio as gr
6
  import torch
7
  import os
8
 
9
# ------------------------------
# 1. Internal Dataset
# ------------------------------
# Small in-memory question/answer corpus used to fine-tune the model.
# Kept as (question, answer) pairs and expanded into the input/output
# record shape the preprocessing step expects.
_qa_pairs = [
    ("What should I do for a cold?", "I'm sorry you're feeling unwell. Drink warm water, get rest, and consider vitamin C."),
    ("What to do if I have a headache?", "I understand headaches are frustrating. Try meditation, rest, and stay hydrated."),
    ("My child has fever, what do I do?", "Give paracetamol, keep them hydrated, and if fever persists, consult a doctor."),
    ("Who can I contact for fever treatment?", "You can reach Dr. Ankit Verma at +91-9876543210 or Dr. Priya Singh at +91-9123456780."),
    ("I feel dizzy, what should I do?", "Sit down, drink water, and rest. If it continues, see a doctor."),
    ("I am anxious and need help.", "Feeling anxious is okay. Try deep breathing. You can also speak with Dr. Richa Nair at +91-9874455667."),
    ("I have mild back pain.", "Gentle stretching and rest can help. For consultation, Dr. Amit Khanna +91-9988774455 is available."),
    ("My child has a cough and cold.", "Dr. Sneha Kapoor at +91-9871122334 and Dr. Arjun Mehta at +91-9112233445 can assist. Keep your child warm."),
]

data = [{"input": question, "output": answer} for question, answer in _qa_pairs]

dataset = Dataset.from_list(data)
24
 
25
  # ------------------------------
26
  # 2. Model & Tokenizer
 
35
  # 3. Preprocess
36
  # ------------------------------
37
def preprocess(examples):
    """Tokenize a batch of input/output pairs for seq2seq fine-tuning.

    Args:
        examples: batch dict with 'input' and 'output' lists of strings
            (the columns produced by the internal dataset).

    Returns:
        Tokenized model inputs with a 'labels' field ready for the Trainer.
    """
    model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
    labels = tokenizer(examples['output'], truncation=True, padding='max_length', max_length=64)
    # Replace pad-token ids in the labels with -100 so the cross-entropy
    # loss ignores padded positions; otherwise the model is trained to
    # predict padding (standard HF seq2seq labeling practice).
    model_inputs['labels'] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in labels['input_ids']
    ]
    return model_inputs
42
 
 
54
  output_dir=OUTPUT_DIR,
55
  per_device_train_batch_size=4,
56
  num_train_epochs=3,
57
+ logging_steps=5,
58
  save_steps=50,
59
  save_total_limit=2,
60
  fp16=torch.cuda.is_available(),
61
+ report_to="none"
62
  )
63
 
64
# Wire model, training config, tokenized data and collator together.
trainer = Trainer(
    args=training_args,
    model=model,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
    tokenizer=None,  # explicitly None to sidestep the tokenizer FutureWarning
)
71
 
 
72
  if not os.path.exists(OUTPUT_DIR):
73
  trainer.train()
74
+ model.save_pretrained(OUTPUT_DIR)
75
+ tokenizer.save_pretrained(OUTPUT_DIR)
76
 
77
  # ------------------------------
78
+ # 6. Load fine-tuned model
79
  # ------------------------------
80
  model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
81
  tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
82
 
83
+ # ------------------------------
84
+ # 7. Gradio UI (chat-style)
85
+ # ------------------------------
86
def respond(user_input, chat_history):
    """Generate a model reply and extend the chat history.

    Args:
        user_input: the user's message text.
        chat_history: existing history list, or None on the first turn.

    Returns:
        (chat_history, chat_history) — once for the Chatbot display and
        once for the gr.State that carries it into the next call.
    """
    inputs = tokenizer(user_input, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=64)
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    chat_history = chat_history or []
    # gr.Chatbot renders history as (user_message, bot_message) pairs;
    # appending labeled ("You", ...)/("Bot", ...) tuples would display the
    # literal labels as messages instead of the actual exchange.
    chat_history.append((user_input, reply))
    return chat_history, chat_history
94
 
95
# Chat UI: a Chatbot pane, a message textbox, and a Send button, all
# driven by respond(), which returns the updated history twice (display + state).
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align:center'>💊 Health Remedies Chatbot</h1>")
    chat_pane = gr.Chatbot()
    user_box = gr.Textbox(placeholder="Type your message here...")
    history = gr.State()
    with gr.Row():
        send_btn = gr.Button("Send")
    # Both clicking Send and pressing Enter route through respond().
    for trigger in (send_btn.click, user_box.submit):
        trigger(respond, [user_box, history], [chat_pane, history])

demo.launch()