shyatri committed on
Commit
ba9807d
·
verified ·
1 Parent(s): f48a29c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py: Fine-tune internally + Gradio UI
2
-
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
4
  from datasets import Dataset
5
  import gradio as gr
@@ -36,7 +34,8 @@ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
36
  # ------------------------------
37
  def preprocess(examples):
38
  model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
39
- labels = tokenizer(examples['output'], truncation=True, padding='max_length', max_length=64)
 
40
  model_inputs['labels'] = labels['input_ids']
41
  return model_inputs
42
 
@@ -65,11 +64,12 @@ trainer = Trainer(
65
  model=model,
66
  args=training_args,
67
  train_dataset=tokenized_dataset,
68
- tokenizer=None, # avoid FutureWarning
69
  data_collator=data_collator
70
  )
71
 
72
- if not os.path.exists(OUTPUT_DIR):
 
73
  trainer.train()
74
  model.save_pretrained(OUTPUT_DIR)
75
  tokenizer.save_pretrained(OUTPUT_DIR)
@@ -93,13 +93,14 @@ def respond(user_input, chat_history):
93
  return chat_history, chat_history
94
 
95
  with gr.Blocks() as demo:
96
- gr.Markdown("<h1 style='text-align:center'>💊 Health Remedies Chatbot</h1>")
97
  chatbot = gr.Chatbot()
98
- msg = gr.Textbox(placeholder="Type your message here...")
99
- state = gr.State()
100
  with gr.Row():
101
- send = gr.Button("Send")
 
102
  send.click(respond, [msg, state], [chatbot, state])
103
  msg.submit(respond, [msg, state], [chatbot, state])
104
 
105
  demo.launch()
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
2
  from datasets import Dataset
3
  import gradio as gr
 
34
  # ------------------------------
35
  def preprocess(examples):
36
  model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
37
+ with tokenizer.as_target_tokenizer():
38
+ labels = tokenizer(examples['output'], truncation=True, padding='max_length', max_length=64)
39
  model_inputs['labels'] = labels['input_ids']
40
  return model_inputs
41
 
 
64
  model=model,
65
  args=training_args,
66
  train_dataset=tokenized_dataset,
67
+ tokenizer=tokenizer, # FIXED
68
  data_collator=data_collator
69
  )
70
 
71
+ # Fine-tune only if directory is missing or empty
72
+ if not os.path.exists(OUTPUT_DIR) or not os.listdir(OUTPUT_DIR):
73
  trainer.train()
74
  model.save_pretrained(OUTPUT_DIR)
75
  tokenizer.save_pretrained(OUTPUT_DIR)
 
93
  return chat_history, chat_history
94
 
95
  with gr.Blocks() as demo:
96
+ gr.Markdown("<h1 style='text-align:center; color:#4CAF50;'>💊 Health Remedies Chatbot</h1>")
97
  chatbot = gr.Chatbot()
98
+ state = gr.State([])
 
99
  with gr.Row():
100
+ msg = gr.Textbox(placeholder="Type your message here...", scale=8)
101
+ send = gr.Button("Send", scale=2)
102
  send.click(respond, [msg, state], [chatbot, state])
103
  msg.submit(respond, [msg, state], [chatbot, state])
104
 
105
  demo.launch()
106
+