# Source: Hugging Face Space file "app.py" by sedtha — commit c2110ef (verified).
import torch
import warnings
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from peft import PeftModel
from transformers import (
MBartForConditionalGeneration, MBart50Tokenizer,
MT5ForConditionalGeneration, T5Tokenizer
)
warnings.filterwarnings("ignore")
# ================= FastAPI app & CORS =================
app = FastAPI(
    title="Khmer Summarization API",
    description="mBART-LoRA + mT5 in ONE API",
    version="1.0.0"
)

# NOTE(review): wildcard patterns such as "https://*.hf.space" are NOT
# supported by `allow_origins` — Starlette compares origins exactly, so that
# entry could never match and was dead. Subdomain matching must be expressed
# through `allow_origin_regex` instead (done below). The bare "*" still
# allows every origin; drop it and tighten this list for production.
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1",
    "http://127.0.0.1:3000",
    "*",  # You can be more restrictive in production
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Allow any Hugging Face Space subdomain (https://<name>.hf.space).
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allows all headers
)
# ================= Device =================
# Prefer the GPU when available; models and tokenized inputs are all moved
# to this device before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================= Models Config =================
# Registry of selectable summarizers, keyed by the public API id that clients
# send in SummarizeRequest.model. "model" and "tokenizer" start as None and
# are filled in lazily by load_model(), so nothing is downloaded at import
# time — only on the first request that needs a given model.
MODELS = {
    "model1": {
        "name": "Khmer mBART + LoRA",
        "type": "mbart",  # selects the mBART+PEFT loading path in load_model()
        "repo": "sedtha/mBart-50-large_LoRa_kh_sumerize",
        "model": None,
        "tokenizer": None
    },
    "model2": {
        "name": "Khmer mT5",
        "type": "mt5",  # selects the mT5 loading path in load_model()
        "repo": "angkor96/khmer-mT5-news-summarization",
        "model": None,
        "tokenizer": None
    }
}
# ================= Load Model =================
# ================= Load Model =================
def load_model(key: str):
    """Lazily load and cache the model/tokenizer pair for ``MODELS[key]``.

    Downloads happen only on the first call for a given key; the loaded
    objects are stored back into the MODELS registry and reused afterwards.

    Args:
        key: A key of MODELS ("model1" or "model2"). Callers are expected
            to validate it before calling (see the /summarize endpoint).

    Returns:
        A ``(model, tokenizer)`` tuple with the model in eval mode on
        ``device``.

    Raises:
        ValueError: If ``MODELS[key]["type"]`` is not a known architecture.
    """
    info = MODELS[key]
    if info["model"] is None:
        print(f"πŸ”Ή Loading {info['name']}...")
        if info["type"] == "mbart":
            tokenizer = MBart50Tokenizer.from_pretrained(
                info["repo"],
                src_lang="km_KH",
                tgt_lang="km_KH",
                cache_dir="./cache"
            )
            # The LoRA repo holds only adapter weights, so the full mBART
            # base model is loaded first and the adapter applied on top.
            base_model = MBartForConditionalGeneration.from_pretrained(
                "facebook/mbart-large-50",
                cache_dir="./cache"
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                info["repo"],
                cache_dir="./cache"
            ).to(device)
        elif info["type"] == "mt5":
            tokenizer = T5Tokenizer.from_pretrained(info["repo"], cache_dir="./cache")
            model = MT5ForConditionalGeneration.from_pretrained(
                info["repo"], cache_dir="./cache"
            ).to(device)
        else:
            # Previously an unrecognized type fell through and raised a
            # confusing NameError on the unbound `model`; fail explicitly.
            raise ValueError(f"Unknown model type: {info['type']!r}")
        model.eval()  # inference only — disables dropout etc.
        info["model"] = model
        info["tokenizer"] = tokenizer
        print(f"βœ… Loaded {info['name']}")
    return info["model"], info["tokenizer"]
# ================= Request Schema =================
# ================= Request Schema =================
class SummarizeRequest(BaseModel):
    """Request body for POST /summarize."""
    # Khmer article text to summarize; must be non-empty after stripping.
    text: str
    # Which MODELS entry to use ("model1" or "model2"); defaults to mT5.
    model: str = "model2"
# ================= API Endpoint =================
@app.post("/summarize")
def summarize(req: SummarizeRequest):
if not req.text.strip():
raise HTTPException(status_code=400, detail="Text is empty")
if req.model not in MODELS:
raise HTTPException(status_code=400, detail="Invalid model")
model, tokenizer = load_model(req.model)
inputs = tokenizer(
req.text,
return_tensors="pt",
truncation=True,
max_length=1024
).to(device)
with torch.no_grad():
summary_ids = model.generate(
**inputs,
do_sample=True,
temperature=0.8,
top_p=0.9,
top_k=50,
max_new_tokens=125,
repetition_penalty=1.2,
no_repeat_ngram_size=3
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Khmer sentence cleanup
if "αŸ”" in summary:
summary = summary[:summary.rfind("αŸ”") + 1]
return {
"model": MODELS[req.model]["name"],
"summary": summary.strip()
}
# ================= Health Check =================
@app.get("/")
def root():
return {"status": "Khmer Summarization API is running πŸš€"}
# ================= Additional endpoint for testing =================
@app.get("/health")
def health_check():
return {
"status": "healthy",
"device": str(device),
"models_loaded": {
key: info["model"] is not None
for key, info in MODELS.items()
}
}
# ================= Pre-load models on startup (optional) =================
@app.on_event("startup")
async def startup_event():
# Optionally pre-load both models on startup
# This will make first request faster but uses more memory
print("πŸš€ Starting up...")
print(f"Using device: {device}")
# You can choose to pre-load models or load them on first request
# For memory efficiency, we'll load on first request
print("Models will be loaded on first request to save memory")