Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
# app.py: Fine-tune internally + Gradio UI
|
| 2 |
-
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
|
| 4 |
from datasets import Dataset
|
| 5 |
import gradio as gr
|
|
@@ -36,7 +34,8 @@ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
|
| 36 |
# ------------------------------
|
| 37 |
def preprocess(examples):
|
| 38 |
model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
|
| 39 |
-
|
|
|
|
| 40 |
model_inputs['labels'] = labels['input_ids']
|
| 41 |
return model_inputs
|
| 42 |
|
|
@@ -65,11 +64,12 @@ trainer = Trainer(
|
|
| 65 |
model=model,
|
| 66 |
args=training_args,
|
| 67 |
train_dataset=tokenized_dataset,
|
| 68 |
-
tokenizer=
|
| 69 |
data_collator=data_collator
|
| 70 |
)
|
| 71 |
|
| 72 |
-
if
|
|
|
|
| 73 |
trainer.train()
|
| 74 |
model.save_pretrained(OUTPUT_DIR)
|
| 75 |
tokenizer.save_pretrained(OUTPUT_DIR)
|
|
@@ -93,13 +93,14 @@ def respond(user_input, chat_history):
|
|
| 93 |
return chat_history, chat_history
|
| 94 |
|
| 95 |
with gr.Blocks() as demo:
|
| 96 |
-
gr.Markdown("<h1 style='text-align:center'>💊 Health Remedies Chatbot</h1>")
|
| 97 |
chatbot = gr.Chatbot()
|
| 98 |
-
|
| 99 |
-
state = gr.State()
|
| 100 |
with gr.Row():
|
| 101 |
-
|
|
|
|
| 102 |
send.click(respond, [msg, state], [chatbot, state])
|
| 103 |
msg.submit(respond, [msg, state], [chatbot, state])
|
| 104 |
|
| 105 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq
|
| 2 |
from datasets import Dataset
|
| 3 |
import gradio as gr
|
|
|
|
| 34 |
# ------------------------------
|
| 35 |
def preprocess(examples):
|
| 36 |
model_inputs = tokenizer(examples['input'], truncation=True, padding='max_length', max_length=64)
|
| 37 |
+
with tokenizer.as_target_tokenizer():
|
| 38 |
+
labels = tokenizer(examples['output'], truncation=True, padding='max_length', max_length=64)
|
| 39 |
model_inputs['labels'] = labels['input_ids']
|
| 40 |
return model_inputs
|
| 41 |
|
|
|
|
| 64 |
model=model,
|
| 65 |
args=training_args,
|
| 66 |
train_dataset=tokenized_dataset,
|
| 67 |
+
tokenizer=tokenizer, # FIXED
|
| 68 |
data_collator=data_collator
|
| 69 |
)
|
| 70 |
|
| 71 |
+
# Fine-tune only if directory is missing or empty
|
| 72 |
+
if not os.path.exists(OUTPUT_DIR) or not os.listdir(OUTPUT_DIR):
|
| 73 |
trainer.train()
|
| 74 |
model.save_pretrained(OUTPUT_DIR)
|
| 75 |
tokenizer.save_pretrained(OUTPUT_DIR)
|
|
|
|
| 93 |
return chat_history, chat_history
|
| 94 |
|
| 95 |
with gr.Blocks() as demo:
|
| 96 |
+
gr.Markdown("<h1 style='text-align:center; color:#4CAF50;'>💊 Health Remedies Chatbot</h1>")
|
| 97 |
chatbot = gr.Chatbot()
|
| 98 |
+
state = gr.State([])
|
|
|
|
| 99 |
with gr.Row():
|
| 100 |
+
msg = gr.Textbox(placeholder="Type your message here...", scale=8)
|
| 101 |
+
send = gr.Button("Send", scale=2)
|
| 102 |
send.click(respond, [msg, state], [chatbot, state])
|
| 103 |
msg.submit(respond, [msg, state], [chatbot, state])
|
| 104 |
|
| 105 |
demo.launch()
|
| 106 |
+
|