| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
| |
# Hugging Face Hub model id of the supervised fine-tuned (SFT) summarizer.
MODEL_ID = "ibz18/sft"


print("Downloading and loading the SFT model...")

# Download (on first run) and load the tokenizer and seq2seq weights from the Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)


# Inference only: switch off dropout and other training-mode behavior.
model.eval()
|
def generate_summary(text):
    """Summarize Bangla input text with the fine-tuned seq2seq model.

    Args:
        text: Raw text from the Gradio textbox. May be ``None`` or
            empty/whitespace-only when the user submits a blank box.

    Returns:
        The generated summary, or a human-readable error message string
        (the Gradio output box simply displays whatever is returned).
    """
    # Guard against None as well as blank input — Gradio can pass None
    # (cleared textbox / null via the API), and None.strip() would raise.
    if not text or not text.strip():
        return "Please enter some Bangla text."

    try:
        # Truncate long inputs to the encoder's supported length.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        # Deterministic beam-search decoding; no gradients needed at inference.
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                num_beams=4,
                repetition_penalty=2.5,
                early_stopping=True,
                # T5-family models begin decoding from the pad token.
                decoder_start_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Surface degenerate (empty) generations explicitly in the UI.
        if not summary or summary.isspace():
            return "ERROR: Model generated an empty string."

        return summary

    # Top-level boundary handler: report any failure as text so the UI
    # shows the error instead of crashing the request.
    except Exception as e:
        return f"CRASH ERROR: {str(e)}"
|
|
| |
# Wire the inference function into a minimal Gradio UI:
# one multiline input textbox -> one output textbox.
demo = gr.Interface(
    fn=generate_summary,
    inputs=gr.Textbox(lines=8, label="Input Bangla Text", placeholder="এখানে আপনার বাংলা টেক্সট দিন..."),
    outputs=gr.Textbox(label="Generated Summary"),
    title="SFT Baseline Model",
    description="Live testing interface for the Supervised Fine-Tuned (SFT) model."
)


# Start the local web server (blocking call).
demo.launch()
|
|