SmolLM-chat / app.py
AREEBAFATIMA12's picture
Update app.py
8a821d4 verified
# import gradio as gr
# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM
# MODEL_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-DPO"
# print("Loading model...")
# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# tokenizer.pad_token = tokenizer.eos_token
# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32).eval()
# print("Model ready.")
# def respond(prompt):
# if not prompt.strip():
# return "Please enter a question."
# formatted = f"<|user|>\n{prompt}</s>\n<|assistant|>\n"
# inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)
# with torch.no_grad():
# out = model.generate(
# **inputs,
# max_new_tokens=200,
# do_sample=False,
# pad_token_id=tokenizer.eos_token_id,
# eos_token_id=tokenizer.eos_token_id,
# )
# return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
# demo = gr.Interface(
# fn=respond,
# inputs=gr.Textbox(lines=3, placeholder="Ask me anything...", label="Your Question"),
# outputs=gr.Textbox(label="Model Response", lines=6),
# title="SmolLM-135M (SFT + DPO)",
# description="SmolLM-135M fine-tuned on Dolly-15k (SFT) and Orca DPO pairs (DPO). Built for IBA NLP Assignment 04.",
# examples=[
# ["What causes seasons on Earth?"],
# ["What is the capital of Australia?"],
# ["Explain photosynthesis briefly."],
# ["Give me 3 tips to reduce plastic waste."],
# ],
# theme=gr.themes.Soft()
# )
# demo.launch(server_name="0.0.0.0", server_port=7860)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model identifiers handed to ``from_pretrained`` for the three variants
# compared by this app.
BASE_ID = "HuggingFaceTB/SmolLM-135M"
SFT_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-only"
DPO_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-DPO"


def _load_checkpoint(model_id):
    """Return ``(tokenizer, model)`` for ``model_id``.

    The tokenizer's pad token is set to its EOS token, and the model is
    loaded with ``torch.float32`` and switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    tok.pad_token = tok.eos_token
    net = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    return tok, net.eval()


print("Loading all 3 models...")
# Base
base_tokenizer, base_model = _load_checkpoint(BASE_ID)
# SFT
sft_tokenizer, sft_model = _load_checkpoint(SFT_ID)
# SFT+DPO
dpo_tokenizer, dpo_model = _load_checkpoint(DPO_ID)
print("All models loaded.")
def _format_prompt(prompt, template):
    """Wrap ``prompt`` in the text layout selected by ``template``."""
    if template == "base":
        # The base model receives the raw prompt with no wrapper.
        return prompt
    if template == "sft":
        return f"### Instruction:\n{prompt}\n\n### Response:\n"
    # "dpo" and any unrecognised template name use the chat-style layout.
    return f"<|user|>\n{prompt}</s>\n<|assistant|>\n"


def generate(model, tokenizer, prompt, template="dpo", *,
             max_new_tokens=150, max_input_tokens=512):
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        model: Object exposing ``generate(**kwargs)`` that returns a batch of
            token-id sequences beginning with the input ids.
        tokenizer: Callable returning a mapping with an ``"input_ids"``
            tensor; must also provide ``eos_token_id`` and ``decode``.
        prompt: The user's text.
        template: ``"base"`` (raw prompt), ``"sft"`` (Instruction/Response
            layout) or anything else (``<|user|>`` / ``<|assistant|>``
            layout, the default).
        max_new_tokens: Upper bound on generated tokens (default 150).
        max_input_tokens: ``max_length`` passed to the tokenizer together
            with ``truncation=True`` (default 512).

    Returns:
        The decoded continuation with special tokens skipped and surrounding
        whitespace stripped.
    """
    formatted = _format_prompt(prompt, template)
    # NOTE(review): with truncation enabled, a very long prompt may lose the
    # trailing template marker depending on the tokenizer's truncation side
    # -- confirm if long inputs matter.
    inputs = tokenizer(
        formatted,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # sampling disabled
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # The returned sequence starts with the prompt ids; skip them so that only
    # the continuation is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
def compare(prompt):
    """Run ``prompt`` through the base, SFT and SFT+DPO models.

    Args:
        prompt: Text from the prompt box. ``None``, an empty string or a
            whitespace-only string is treated as "no input".

    Returns:
        A 3-tuple of strings ``(base_output, sft_output, dpo_output)``, or
        three empty strings when there is no input.
    """
    # Guard against None as well as blank text so the callback never raises
    # on an empty submission.
    if prompt is None or not prompt.strip():
        return "", "", ""
    # Each model is queried with the template selected by name below; the
    # prompt is forwarded unmodified (not stripped).
    base_out = generate(base_model, base_tokenizer, prompt, template="base")
    sft_out = generate(sft_model, sft_tokenizer, prompt, template="sft")
    dpo_out = generate(dpo_model, dpo_tokenizer, prompt, template="dpo")
    return base_out, sft_out, dpo_out
# Build the comparison UI: one prompt box, one button, and three side-by-side
# output boxes (one per model variant).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SmolLM-135M: Base vs SFT vs SFT+DPO")
    gr.Markdown("Compare outputs across the full fine-tuning pipeline. Built for IBA NLP Assignment 04.")
    prompt_box = gr.Textbox(lines=2, placeholder="Ask a question...", label="Your Prompt")
    btn = gr.Button("Compare All 3 Models", variant="primary")
    # The three outputs share a row so they can be read side by side.
    with gr.Row():
        out_base = gr.Textbox(label="Base Model (no tuning)", lines=6)
        out_sft = gr.Textbox(label="SFT Only (Trial 3)", lines=6)
        out_dpo = gr.Textbox(label="SFT + DPO (Final)", lines=6)
    # Example prompts that fill the prompt box.
    gr.Examples(
        examples=[
            ["What causes seasons on Earth?"],
            ["What is the capital of Australia?"],
            ["Explain photosynthesis briefly."],
            ["Give me 3 tips to reduce plastic waste."],
            ["What are the planets in our solar system?"],
        ],
        inputs=prompt_box,
    )
    # The click handler is registered inside the Blocks context; compare()
    # returns one string per output box, in this order.
    btn.click(fn=compare, inputs=prompt_box, outputs=[out_base, out_sft, out_dpo])

# Listen on all interfaces at port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)