import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
    MBartForConditionalGeneration,
    MBart50Tokenizer,
    MT5ForConditionalGeneration,
    T5Tokenizer,
)
from peft import PeftModel
import warnings
from dotenv import load_dotenv

# ================== Config ==================
load_dotenv()
warnings.filterwarnings("ignore", category=FutureWarning)

app = FastAPI(title="Khmer Summarization API")

# Allow CORS for the JS frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Models ==================
# Registry of available summarizers. "model" and "tokenizer" start as None
# and are filled in lazily on first use (see load_model below).
MODELS = {
    "model1": {
        "name": "Model 1 - Khmer MBart Summarization (LoRA)",
        "repo": "sedtha/mBart-50-large_LoRa_kh_sumerize",
        "type": "mbart_lora",
        "model": None,
        "tokenizer": None,
    },
    "model2": {
        "name": "Model 2 - Khmer mT5 Summarization",
        "repo": "angkor96/khmer-mT5-news-summarization",
        "type": "mt5",
        "model": None,
        "tokenizer": None,
    },
}

# ================== Load Model ==================
def load_model(model_key):
    """Lazily load and cache the requested model and tokenizer."""
    model_info = MODELS[model_key]
    if model_info["model"] is None:
        print(f"🔄 Loading {model_info['name']} ...")
        if model_info["type"] == "mbart_lora":
            # Khmer-to-Khmer summarization: source and target are both km_KH.
            tokenizer = MBart50Tokenizer.from_pretrained(
                model_info["repo"],
                src_lang="km_KH",
                tgt_lang="km_KH",
            )
            # Load the base mBART-50 checkpoint, apply the LoRA adapter,
            # then merge the adapter weights for faster inference.
            base = MBartForConditionalGeneration.from_pretrained(
                "facebook/mbart-large-50"
            ).to(device)
            model = (
                PeftModel.from_pretrained(base, model_info["repo"])
                .merge_and_unload()
                .to(device)
            )
        elif model_info["type"] == "mt5":
            tokenizer = T5Tokenizer.from_pretrained(model_info["repo"])
            model = MT5ForConditionalGeneration.from_pretrained(
                model_info["repo"]
            ).to(device)
        else:
            raise ValueError(f"Unknown model type: {model_info['type']}")
        model.eval()
        model_info["tokenizer"] = tokenizer
        model_info["model"] = model
        print(f"✅ Loaded {model_info['name']}")
    return model_info["model"], model_info["tokenizer"]

# ================== Request Model ==================
class SummarizeRequest(BaseModel):
    text: str
    models: list[str] = ["model1"]

# ================== Summarization ==================
def summarize_text(text, model_key):
    model, tokenizer = load_model(model_key)
    # Truncate long articles to the encoder's 1024-token limit.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=1024
    ).to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=300,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

# ================== Endpoints ==================
@app.get("/")
def root():
    return {"message": "✅ Khmer Summarization API (FastAPI) Running!"}

@app.get("/models")
def list_models():
    return {key: {"name": v["name"]} for key, v in MODELS.items()}

@app.post("/summarize")
def summarize(req: SummarizeRequest):
    # Declared sync (not async) so FastAPI runs it in a worker thread and
    # the blocking model.generate() call does not stall the event loop.
    if not req.text.strip():
        # Khmer: "Please enter some text first!"
        return {"error": "⚠️ សូមវាយបញ្ចូលអត្ថបទជាមុន!"}

    results = {}
    for key in req.models:
        if key in MODELS:
            try:
                summary = summarize_text(req.text, key)
            except Exception as e:
                summary = f"Error: {str(e)}"
            results[key] = {
                "name": MODELS[key]["name"],
                "summary": summary,
            }
    return {"results": results}
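
# ================== Local Run (sketch) ==================
# A minimal entry point for local development. This block is an illustrative
# addition, not part of the original service: it assumes the module is saved
# as main.py and that uvicorn is installed; host, port, and reload are
# hypothetical defaults, not values taken from the original code.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

# Hypothetical example request against the /summarize endpoint:
#   curl -X POST http://localhost:8000/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "<Khmer article text>", "models": ["model1", "model2"]}'
#
# Expected response shape:
#   {"results": {"model1": {"name": "...", "summary": "..."},
#                "model2": {"name": "...", "summary": "..."}}}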