"""Gradio demo comparing SmolLM-135M base, SFT-only and SFT+DPO checkpoints.

Built for IBA NLP Assignment 04. All three models are loaded at import time and
each submitted prompt is answered by every model, shown side by side.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

BASE_ID = "HuggingFaceTB/SmolLM-135M"
SFT_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-only"
DPO_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-DPO"

# Prompt format used for each checkpoint. Presumably these match the formats the
# models were fine-tuned with -- confirm against the training code if outputs look off.
PROMPT_TEMPLATES = {
    "base": "{prompt}",
    "sft": "### Instruction:\n{prompt}\n\n### Response:\n",
    "dpo": "<|user|>\n{prompt}\n<|assistant|>\n",
}


def _load(model_id):
    """Return ``(tokenizer, model)`` for *model_id*, loaded in float32 and eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the pad token (presumably the tokenizer defines none of its own).
    tokenizer.pad_token = tokenizer.eos_token
    # When a prompt exceeds max_length, drop tokens from the *start*, so the
    # template's trailing response cue (e.g. "<|assistant|>\n") is never cut off.
    tokenizer.truncation_side = "left"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).eval()
    return tokenizer, model


print("Loading all 3 models...")
base_tokenizer, base_model = _load(BASE_ID)
sft_tokenizer, sft_model = _load(SFT_ID)
dpo_tokenizer, dpo_model = _load(DPO_ID)
print("All models loaded.")


def generate(model, tokenizer, prompt, template="dpo", max_new_tokens=150):
    """Greedily complete *prompt* with *model* and return only the generated text.

    Args:
        model: causal language model to generate with.
        tokenizer: tokenizer paired with *model*.
        prompt: raw user text.
        template: key into ``PROMPT_TEMPLATES``; unknown keys fall back to "dpo".
        max_new_tokens: maximum number of tokens to generate (default 150).

    Returns:
        The decoded continuation (prompt tokens excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    pattern = PROMPT_TEMPLATES.get(template, PROMPT_TEMPLATES["dpo"])
    formatted = pattern.format(prompt=prompt)
    # Cap the encoded prompt at 512 tokens.
    inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic outputs
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


def compare(prompt):
    """Answer *prompt* with all three models.

    Returns:
        A ``(base, sft, sft_dpo)`` tuple of strings, or three empty strings when
        the prompt is missing or whitespace-only.
    """
    if not prompt or not prompt.strip():
        return "", "", ""
    return (
        generate(base_model, base_tokenizer, prompt, template="base"),
        generate(sft_model, sft_tokenizer, prompt, template="sft"),
        generate(dpo_model, dpo_tokenizer, prompt, template="dpo"),
    )


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SmolLM-135M: Base vs SFT vs SFT+DPO")
    gr.Markdown("Compare outputs across the full fine-tuning pipeline. Built for IBA NLP Assignment 04.")
    prompt_box = gr.Textbox(lines=2, placeholder="Ask a question...", label="Your Prompt")
    btn = gr.Button("Compare All 3 Models", variant="primary")
    with gr.Row():
        out_base = gr.Textbox(label="Base Model (no tuning)", lines=6)
        out_sft = gr.Textbox(label="SFT Only (Trial 3)", lines=6)
        out_dpo = gr.Textbox(label="SFT + DPO (Final)", lines=6)
    gr.Examples(
        examples=[
            ["What causes seasons on Earth?"],
            ["What is the capital of Australia?"],
            ["Explain photosynthesis briefly."],
            ["Give me 3 tips to reduce plastic waste."],
            ["What are the planets in our solar system?"],
        ],
        inputs=prompt_box,
    )
    btn.click(fn=compare, inputs=prompt_box, outputs=[out_base, out_sft, out_dpo])


if __name__ == "__main__":
    # Listen on all interfaces at port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)