|
|
import torch |
|
|
import warnings |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from peft import PeftModel |
|
|
from transformers import ( |
|
|
MBartForConditionalGeneration, MBart50Tokenizer, |
|
|
MT5ForConditionalGeneration, T5Tokenizer |
|
|
) |
|
|
|
|
|
# Silence library warnings (transformers/peft are chatty at load time).
warnings.filterwarnings("ignore")


# Single FastAPI application serving both summarization models.
app = FastAPI(
    title="Khmer Summarization API",
    description="mBART-LoRA + mT5 in ONE API",
    version="1.0.0"
)
|
|
|
|
|
|
|
|
|
|
|
# Explicit origins for local development. Hugging Face Spaces subdomains
# are matched via ``allow_origin_regex`` below: ``allow_origins`` does an
# exact string comparison, so a literal "https://*.hf.space" entry would
# never match any real origin.
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1",
    "http://127.0.0.1:3000",
]


# NOTE(review): the previous configuration combined allow_origins=["*"]
# with allow_credentials=True; the CORS spec forbids the wildcard origin
# for credentialed requests, so browsers reject that combination. Matching
# *.hf.space with a regex keeps credential support working.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Registry of the two supported models. ``model`` / ``tokenizer`` start as
# None and are filled in lazily by load_model() on first use, so nothing
# heavy is downloaded at import time.
MODELS = {
    "model1": {
        "name": "Khmer mBART + LoRA",
        "type": "mbart",  # LoRA adapter applied on top of facebook/mbart-large-50
        "repo": "sedtha/mBart-50-large_LoRa_kh_sumerize",
        "model": None,
        "tokenizer": None
    },
    "model2": {
        "name": "Khmer mT5",
        "type": "mt5",  # standalone fine-tuned mT5 checkpoint
        "repo": "angkor96/khmer-mT5-news-summarization",
        "model": None,
        "tokenizer": None
    }
}
|
|
|
|
|
|
|
|
def load_model(key: str):
    """Lazily load and cache the model/tokenizer pair registered under *key*.

    The pair is stored back into the module-level ``MODELS`` registry so
    subsequent requests reuse the already-loaded weights.

    Parameters
    ----------
    key : str
        A key of the ``MODELS`` registry ("model1" / "model2").

    Returns
    -------
    tuple
        ``(model, tokenizer)`` with the model in eval mode on ``device``.

    Raises
    ------
    ValueError
        If the registry entry has an unrecognized ``type``.
    """
    info = MODELS[key]

    # Already loaded by an earlier request -- return the cached pair.
    if info["model"] is not None:
        return info["model"], info["tokenizer"]

    print(f"Loading {info['name']}...")

    if info["type"] == "mbart":
        tokenizer = MBart50Tokenizer.from_pretrained(
            info["repo"],
            src_lang="km_KH",
            tgt_lang="km_KH",
            cache_dir="./cache"
        )

        # The fine-tuned repo only holds LoRA adapters; load the stock
        # mBART-50 weights first, then apply the adapters on top.
        base_model = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50",
            cache_dir="./cache"
        ).to(device)

        model = PeftModel.from_pretrained(
            base_model,
            info["repo"],
            cache_dir="./cache"
        ).to(device)

    elif info["type"] == "mt5":
        tokenizer = T5Tokenizer.from_pretrained(info["repo"], cache_dir="./cache")
        model = MT5ForConditionalGeneration.from_pretrained(
            info["repo"], cache_dir="./cache"
        ).to(device)

    else:
        # Previously an unknown type fell through and crashed with a
        # NameError on ``model``; fail loudly with a clear message instead.
        raise ValueError(f"Unknown model type: {info['type']}")

    model.eval()  # inference only: disable dropout etc.
    info["model"] = model
    info["tokenizer"] = tokenizer

    print(f"Loaded {info['name']}")

    return info["model"], info["tokenizer"]
|
|
|
|
|
|
|
|
class SummarizeRequest(BaseModel):
    """Request body for POST /summarize."""

    # Khmer text to summarize; rejected by the endpoint if blank.
    text: str
    # Key into the MODELS registry; defaults to the mT5 model.
    model: str = "model2"
|
|
|
|
|
|
|
|
@app.post("/summarize") |
|
|
def summarize(req: SummarizeRequest): |
|
|
if not req.text.strip(): |
|
|
raise HTTPException(status_code=400, detail="Text is empty") |
|
|
|
|
|
if req.model not in MODELS: |
|
|
raise HTTPException(status_code=400, detail="Invalid model") |
|
|
|
|
|
model, tokenizer = load_model(req.model) |
|
|
|
|
|
inputs = tokenizer( |
|
|
req.text, |
|
|
return_tensors="pt", |
|
|
truncation=True, |
|
|
max_length=1024 |
|
|
).to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
summary_ids = model.generate( |
|
|
**inputs, |
|
|
do_sample=True, |
|
|
temperature=0.8, |
|
|
top_p=0.9, |
|
|
top_k=50, |
|
|
max_new_tokens=125, |
|
|
repetition_penalty=1.2, |
|
|
no_repeat_ngram_size=3 |
|
|
) |
|
|
|
|
|
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) |
|
|
|
|
|
|
|
|
if "α" in summary: |
|
|
summary = summary[:summary.rfind("α") + 1] |
|
|
|
|
|
return { |
|
|
"model": MODELS[req.model]["name"], |
|
|
"summary": summary.strip() |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def root(): |
|
|
return {"status": "Khmer Summarization API is running π"} |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
def health_check(): |
|
|
return { |
|
|
"status": "healthy", |
|
|
"device": str(device), |
|
|
"models_loaded": { |
|
|
key: info["model"] is not None |
|
|
for key, info in MODELS.items() |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@app.on_event("startup") |
|
|
async def startup_event(): |
|
|
|
|
|
|
|
|
print("π Starting up...") |
|
|
print(f"Using device: {device}") |
|
|
|
|
|
|
|
|
|
|
|
print("Models will be loaded on first request to save memory") |