| | import gradio as gr |
| | import torch |
| | import os |
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
# Hugging Face repo hosting the fine-tuned Model_D checkpoint (abstracter_rl.pt).
MODEL_REPO = "ibz18/Model_D_weights"
# Pretrained base model whose architecture and tokenizer the checkpoint is loaded into.
BASE_MODEL = "csebuetnlp/banglat5"

# Optional Hugging Face access token read from the environment; None when
# HF_TOKEN is unset (presumably only needed if MODEL_REPO is private — confirm).
hf_token = os.getenv("HF_TOKEN")
| |
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored in a loaded checkpoint object.

    Checkpoints saved as ``{"state_dict": ...}`` or ``{"model_state_dict": ...}``
    are unwrapped; anything else is assumed to already be a state dict.
    """
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in checkpoint:
                return checkpoint[wrapper_key]
    return checkpoint


def _strip_common_prefix(state_dict, prefixes=("module.", "model.")):
    """Remove a wrapper prefix (e.g. DataParallel's ``module.``) shared by every key.

    A prefix is stripped only when *all* keys carry it, so a checkpoint whose
    keys already match the model's parameter names is returned unchanged.
    """
    for prefix in prefixes:
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


print("1. Downloading .pt file...")
abstracter_rl_path = hf_hub_download(
    repo_id=MODEL_REPO,
    filename="abstracter_rl.pt",
    token=hf_token,
)

print("2. Loading tokenizer and base model...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL)

print("3. Resizing embeddings...")
# Size the embedding matrix to the tokenizer's vocabulary before loading;
# assumes the checkpoint was saved after the same resize — TODO confirm.
model.resize_token_embeddings(len(tokenizer))

print("4. Injecting .pt weights into memory...")
# weights_only=True restricts unpickling to tensors and plain containers,
# so the downloaded file cannot execute arbitrary code on load.
checkpoint = torch.load(abstracter_rl_path, map_location="cpu", weights_only=True)
state_dict = _strip_common_prefix(_extract_state_dict(checkpoint))

# strict=False tolerates partial checkpoints, but without inspecting the
# result a key mismatch would silently leave the base (un-fine-tuned)
# weights in place. Report mismatches so that failure is visible in the logs.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model parameters were not in the "
        f"checkpoint and keep their base values, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint entries were ignored, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )

# load_state_dict copies the tensors into the model's own parameters, so the
# checkpoint copy is now redundant; drop it to lower peak memory.
del checkpoint, state_dict

model.eval()
| |
|
def generate_summary(text):
    """Summarize a Bangla passage with the fine-tuned Model_D checkpoint.

    Args:
        text: Raw input from the Gradio textbox. ``None`` and blank input are
            rejected with a prompt message instead of being sent to the model.

    Returns:
        The decoded summary. Otherwise a human-readable message: when the
        input is empty, when the model produced no text, or when
        tokenization/generation raised.
    """
    # Treat None like empty input so callers get the prompt, not an AttributeError.
    if text is None or not text.strip():
        return "Please enter Bangla text."

    # NOTE(review): the csebuetnlp/banglat5 model card recommends normalizing
    # input with csebuetnlp's `normalizer` package before tokenizing; confirm
    # whether the checkpoint was trained on normalized text.
    try:
        # Inputs longer than 512 tokens are truncated (the tail is dropped).
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # deterministic beam search
                num_beams=2,
                repetition_penalty=2.5,
                early_stopping=True,
                # T5 convention: decoding starts from the pad token.
                decoder_start_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: show the failure in the textbox, but keep the full
        # traceback in the server log so it can be diagnosed.
        import traceback

        traceback.print_exc()
        return f"CRASH ERROR: {str(e)}"

    return summary if summary.strip() else "ERROR: Empty string"
| |
|
| | |
# Gradio UI: one multi-line Bangla textbox in, the generated summary out.
# The placeholder reads roughly "Enter your Bangla text here...".
_input_box = gr.Textbox(
    lines=8,
    label="Input Bangla Text",
    placeholder="এখানে আপনার বাংলা টেক্সট দিন...",
)
_output_box = gr.Textbox(label="Generated Summary")

demo = gr.Interface(
    fn=generate_summary,
    inputs=_input_box,
    outputs=_output_box,
    title="Model_D",
    description="Live testing interface for Model_D",
)

# Start the Gradio server for the interface defined above.
demo.launch()
| |
|