# Punctuation / app.py
# Author: sakibzaman
# Commit 257b61e ("update")
# app.py
# Bengali punctuation-restoration demo: loads a 4-bit quantized, fine-tuned
# Llama-3.2-3B model from the Hugging Face Hub and serves it through Gradio.
import os
# These environment variables are set BEFORE transformers is imported so they
# are already in place whenever the libraries read them; keep this ordering.
# Suppress transformers' advisory warnings.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
# Turn off the Hugging Face tokenizers library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig
)
import torch
import gradio as gr
from huggingface_hub import login
# --- Step 1: Authenticate with the Hugging Face Hub ---
# The token is read from the HF_TOKEN environment variable (on Spaces, add it
# under Settings > Secrets). The name must be exactly HF_TOKEN, not HF_API_TOKEN.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    # Fail at startup with an actionable message instead of a later 401.
    raise ValueError("❌ HF_TOKEN not set! Go to Settings > Secrets and add your Hugging Face token as 'HF_TOKEN'.")
# Log in once so subsequent Hub downloads in this process are authenticated.
login(token=HF_TOKEN)
# --- Step 2: 4-bit quantization settings ---
# Weights are stored as 4-bit NF4 with nested ("double") quantization of the
# quantization constants; matmuls run in float16. CPU offload is intentionally
# left at its default (not enabled).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
# --- Step 3: Load tokenizer and quantized model ---
# NOTE(review): this repo id names an adapter; loading it directly through
# AutoModelForCausalLM presumably relies on transformers' PEFT integration
# (peft installed, adapter config pointing at the base model) — confirm.
model_id = "sakibzaman/llama3.2-3b-bengali-punc-adapter"
try:
    print("🔍 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    # Reuse EOS as the pad token when none is defined so generate() always
    # receives a valid pad_token_id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print("🧠 Loading model with 4-bit quantization...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="balanced",  # let accelerate spread layers evenly over visible GPUs
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        trust_remote_code=False,
    )
    print("✅ Model loaded successfully!")
except Exception as e:
    # Abort startup; chain the original exception so its traceback is preserved.
    raise RuntimeError(f"❌ Failed to load model: {e}") from e
# --- Step 4: Inference Function ---
# Bounds for the generation budget. The lower bound equals the previous fixed
# max_new_tokens; the upper bound stops a runaway generation from tying up
# the GPU.
_MIN_NEW_TOKENS = 64
_MAX_NEW_TOKENS = 512


def _new_token_budget(n_text_tokens):
    """Return a max_new_tokens value scaled to the input text's token count.

    The expected output is the input text plus punctuation marks, so allow
    about 1.5x the input tokens plus a little slack, clamped to
    [_MIN_NEW_TOKENS, _MAX_NEW_TOKENS].
    """
    return max(_MIN_NEW_TOKENS, min(_MAX_NEW_TOKENS, int(n_text_tokens * 1.5) + 16))


def punctuate_text(input_text):
    """Restore punctuation in Bengali text using the loaded model.

    Args:
        input_text: Raw, unpunctuated Bengali text. ``None`` and
            whitespace-only input are treated as empty.

    Returns:
        The punctuated text, a warning string for empty input, or an error
        string if generation fails. Errors are returned rather than raised so
        they appear in the Gradio output box.
    """
    text = (input_text or "").strip()
    if not text:
        return "⚠️ Please enter some Bengali text."
    instruction = "Add appropriate punctuation to the following Bengali text to make it grammatically correct and readable."
    # NOTE(review): [INST] ... [/INST] is not the Llama 3 chat template;
    # presumably it matches the fine-tuning format of the adapter — confirm.
    prompt = f"[INST] {instruction}\n\n{text} [/INST]"
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Size the budget from the text alone (not the instruction) so long
        # inputs are not truncated at a fixed token count.
        n_text_tokens = len(tokenizer(text, add_special_tokens=False)["input_ids"])
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=_new_token_budget(n_text_tokens),
                temperature=0.1,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                # NOTE(review): repetition_penalty also penalizes tokens already
                # in the prompt, i.e. the words this task must copy — consider 1.0.
                repetition_penalty=1.1,
            )
        # For a causal LM, generate() returns prompt + continuation. Decode only
        # the new tokens rather than searching the decoded text for "[/INST]",
        # which breaks if decoding does not round-trip the prompt exactly or the
        # user text itself contains that marker.
        prompt_len = inputs["input_ids"].shape[-1]
        result = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Defensive cleanup in case end-of-turn markers survive decoding.
        result = result.replace("<|eot_id|>", "").replace("<|end_of_text|>", "").strip()
        return result
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return f"❌ Error during generation: {e}"
# --- Step 5: Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft(), title="🇧🇩 Bengali Punctuation") as demo:
    # Header (string contents kept byte-identical to the original).
    gr.HTML("""
<h1 style="text-align: center;">🇧🇩 Bengali Punctuation Assistant</h1>
<p style="text-align: center;">Restore punctuation using Llama-3.2-3B (fine-tuned)</p>
""")
    input_box = gr.Textbox(
        label="📝 Input Bengali Text",
        placeholder="আমি ঢাকায় থাকি আমার বাড়ি সিলেটে",
        lines=4,
    )
    btn = gr.Button("🔤 Add Punctuation", variant="primary")
    output_box = gr.Textbox(label="✅ Output", lines=4, interactive=False)
    # Examples only fill the input box: no fn= is given, so nothing runs until
    # the user presses the button.
    gr.Examples(
        examples=[
            "আপনি কেমন আছেন",
            "আমি ঢাকায় থাকি আমার বাড়ি সিলেটে",
            "তুমি কখন আসবে আমি অপেক্ষা করছি",
        ],
        inputs=input_box,
        label="Click to use example text (then click Add Punctuation button)",
    )
    # Inference runs only on an explicit click; api_name also exposes it as
    # the "punctuate" API endpoint.
    btn.click(
        fn=punctuate_text,
        inputs=input_box,
        outputs=output_box,
        api_name="punctuate",
    )
    gr.HTML("""
<div style="text-align: center; margin-top: 20px; color: #555;">
Powered by Llama-3.2-3B • Hosted on Hugging Face Spaces
</div>
""")
# Launch the app when this file is run as a script. debug=True keeps the
# process attached and prints errors to the console.
if __name__ == "__main__":
    demo.launch(debug=True)