Spaces:
Running
Running
| # import gradio as gr | |
| # import torch | |
| # from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # MODEL_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-DPO" | |
| # print("Loading model...") | |
| # tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
| # tokenizer.pad_token = tokenizer.eos_token | |
| # model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32).eval() | |
| # print("Model ready.") | |
| # def respond(prompt): | |
| # if not prompt.strip(): | |
| # return "Please enter a question." | |
| # formatted = f"<|user|>\n{prompt}</s>\n<|assistant|>\n" | |
| # inputs = tokenizer(formatted, return_tensors="pt", truncation=True, max_length=512) | |
| # with torch.no_grad(): | |
| # out = model.generate( | |
| # **inputs, | |
| # max_new_tokens=200, | |
| # do_sample=False, | |
| # pad_token_id=tokenizer.eos_token_id, | |
| # eos_token_id=tokenizer.eos_token_id, | |
| # ) | |
| # return tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip() | |
| # demo = gr.Interface( | |
| # fn=respond, | |
| # inputs=gr.Textbox(lines=3, placeholder="Ask me anything...", label="Your Question"), | |
| # outputs=gr.Textbox(label="Model Response", lines=6), | |
| # title="SmolLM-135M (SFT + DPO)", | |
| # description="SmolLM-135M fine-tuned on Dolly-15k (SFT) and Orca DPO pairs (DPO). Built for IBA NLP Assignment 04.", | |
| # examples=[ | |
| # ["What causes seasons on Earth?"], | |
| # ["What is the capital of Australia?"], | |
| # ["Explain photosynthesis briefly."], | |
| # ["Give me 3 tips to reduce plastic waste."], | |
| # ], | |
| # theme=gr.themes.Soft() | |
| # ) | |
| # demo.launch(server_name="0.0.0.0", server_port=7860) | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
BASE_ID = "HuggingFaceTB/SmolLM-135M"
SFT_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-only"
DPO_ID = "AREEBAFATIMA12/SmolLM-135M-SFT-DPO"


def _load_tokenizer_and_model(model_id):
    """Load the tokenizer and float32, eval-mode causal LM for *model_id*.

    The tokenizer's pad token is set to its EOS token before returning.

    Args:
        model_id: Hugging Face Hub repository id passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Use EOS as the padding token for this tokenizer.
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).eval()
    return tokenizer, model


print("Loading all 3 models...")
# Loaded in pipeline order: base -> SFT -> SFT+DPO.
base_tokenizer, base_model = _load_tokenizer_and_model(BASE_ID)
sft_tokenizer, sft_model = _load_tokenizer_and_model(SFT_ID)
dpo_tokenizer, dpo_model = _load_tokenizer_and_model(DPO_ID)
print("All models loaded.")
def generate(model, tokenizer, prompt, template="dpo", *, max_new_tokens=150, max_length=512):
    """Greedily generate a completion of *prompt* with *model*.

    Args:
        model: Causal LM exposing ``generate`` (here, an ``AutoModelForCausalLM``).
        tokenizer: Tokenizer paired with *model*.
        prompt: Raw user text.
        template: ``"base"`` feeds the prompt unchanged; ``"sft"`` wraps it in the
            ``### Instruction:`` / ``### Response:`` format; any other value wraps it
            in the ``<|user|>`` / ``<|assistant|>`` chat format.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_length: Maximum number of prompt tokens passed to the model. Longer
            prompts keep only their final ``max_length`` tokens.

    Returns:
        The decoded continuation only (prompt tokens removed), stripped of
        surrounding whitespace and special tokens.
    """
    if template == "base":
        formatted = prompt
    elif template == "sft":
        formatted = f"### Instruction:\n{prompt}\n\n### Response:\n"
    else:
        formatted = f"<|user|>\n{prompt}</s>\n<|assistant|>\n"

    inputs = tokenizer(formatted, return_tensors="pt")
    # Keep the *last* max_length tokens. Default right-side truncation would drop
    # the trailing "### Response:" / "<|assistant|>" marker the model continues from.
    # NOTE(review): if this tokenizer prepends a BOS token, it is dropped for
    # prompts longer than max_length.
    inputs = {key: value[:, -max_length:] for key, value in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Slice off the echoed prompt so only newly generated tokens are decoded.
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
def compare(prompt):
    """Run *prompt* through the base, SFT and SFT+DPO models.

    Args:
        prompt: User text from the input box; may be ``None``.

    Returns:
        A ``(base, sft, dpo)`` tuple of decoded completions, or three empty
        strings when *prompt* is ``None`` or whitespace-only.
    """
    # Gradio may pass None for a null textbox payload (e.g. via the API); treat
    # it like an empty prompt instead of crashing on .strip().
    if prompt is None or not prompt.strip():
        return "", "", ""
    # Each model gets the prompt template it was trained with.
    return (
        generate(base_model, base_tokenizer, prompt, template="base"),
        generate(sft_model, sft_tokenizer, prompt, template="sft"),
        generate(dpo_model, dpo_tokenizer, prompt, template="dpo"),
    )
# Sample questions offered below the prompt box (each becomes a one-input example row).
_EXAMPLE_PROMPTS = (
    "What causes seasons on Earth?",
    "What is the capital of Australia?",
    "Explain photosynthesis briefly.",
    "Give me 3 tips to reduce plastic waste.",
    "What are the planets in our solar system?",
)
# Labels for the three side-by-side output panes, in pipeline order.
_PANE_LABELS = ("Base Model (no tuning)", "SFT Only (Trial 3)", "SFT + DPO (Final)")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header.
    for heading in (
        "# SmolLM-135M: Base vs SFT vs SFT+DPO",
        "Compare outputs across the full fine-tuning pipeline. Built for IBA NLP Assignment 04.",
    ):
        gr.Markdown(heading)

    prompt_box = gr.Textbox(lines=2, placeholder="Ask a question...", label="Your Prompt")
    btn = gr.Button("Compare All 3 Models", variant="primary")

    # One read-only result pane per model, laid out in a single row.
    with gr.Row():
        out_base, out_sft, out_dpo = [gr.Textbox(label=pane, lines=6) for pane in _PANE_LABELS]

    gr.Examples(examples=[[question] for question in _EXAMPLE_PROMPTS], inputs=prompt_box)

    # compare() returns a (base, sft, dpo) tuple matching the pane order.
    result_panes = [out_base, out_sft, out_dpo]
    btn.click(fn=compare, inputs=prompt_box, outputs=result_panes)

demo.launch(server_name="0.0.0.0", server_port=7860)