Spaces:
Build error
Build error
import asyncio
import threading
import warnings

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from peft import PeftModel
from pydantic import BaseModel
from transformers import (
    MBartForConditionalGeneration,
    MBart50Tokenizer,
    MT5ForConditionalGeneration,
    T5Tokenizer,
)
| # ================== Config ================== | |
# Load environment settings via python-dotenv before anything else runs.
load_dotenv()
# Keep FutureWarning noise out of the service logs.
warnings.filterwarnings("ignore", category=FutureWarning)

app = FastAPI(title="Khmer Summarization API")

# Allow CORS for JS frontend.
# Credentials are disabled on purpose: a wildcard origin combined with
# allow_credentials=True is not a valid CORS configuration (the middleware
# would end up echoing back any requesting Origin together with
# Access-Control-Allow-Credentials). Nothing in this module reads cookies or
# auth headers, so a credential-less wildcard policy is sufficient.
# If credentialed requests are ever needed, list the allowed origins
# explicitly instead of using "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| # ================== Models ================== | |
# Registry of the summarization models this service can serve, keyed by the
# identifier clients send in a request.
#   name      - human-readable label
#   repo      - repository id handed to ``from_pretrained``
#   type      - loader selector checked by ``load_model``
#   model     - cache slot, None until ``load_model`` fills it
#   tokenizer - cache slot, None until ``load_model`` fills it
MODELS = {
    "model1": dict(
        name="Model 1 - Khmer MBart Summarization (LoRA)",
        repo="sedtha/mBart-50-large_LoRa_kh_sumerize",
        type="mbart_lora",
        model=None,
        tokenizer=None,
    ),
    "model2": dict(
        name="Model 2 - Khmer mT5 Summarization",
        repo="angkor96/khmer-mT5-news-summarization",
        type="mt5",
        model=None,
        tokenizer=None,
    ),
}
| # ================== Load Model ================== | |
def load_model(model_key):
    """Return the ``(model, tokenizer)`` pair for ``model_key``, loading on first use.

    The loaded objects are stored back into the ``MODELS`` entry, so later
    calls for the same key return the cached pair without loading again.

    Args:
        model_key: Key of an entry in ``MODELS``.

    Returns:
        Tuple ``(model, tokenizer)`` taken from the registry entry.

    Raises:
        KeyError: If ``model_key`` is not a key of ``MODELS``.
        ValueError: If the entry's ``"type"`` is neither ``"mbart_lora"``
            nor ``"mt5"``.
    """
    model_info = MODELS[model_key]

    # Fast path: this entry was populated by an earlier call.
    if model_info["model"] is not None:
        return model_info["model"], model_info["tokenizer"]

    print(f"π Loading {model_info['name']} ...")

    kind = model_info["type"]
    if kind == "mbart_lora":
        tok = MBart50Tokenizer.from_pretrained(
            model_info["repo"], src_lang="km_KH", tgt_lang="km_KH"
        )
        # The adapter repo is applied on top of the stock mBART-50
        # checkpoint, then merged so a plain model is what gets cached.
        backbone = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50"
        ).to(device)
        adapted = PeftModel.from_pretrained(backbone, model_info["repo"])
        net = adapted.merge_and_unload().to(device)
    elif kind == "mt5":
        tok = T5Tokenizer.from_pretrained(model_info["repo"])
        net = MT5ForConditionalGeneration.from_pretrained(model_info["repo"]).to(device)
    else:
        raise ValueError("Unknown model type")

    net.eval()

    # Publish into the registry only after loading fully succeeded.
    model_info["tokenizer"] = tok
    model_info["model"] = net
    print(f"β Loaded {model_info['name']}")

    return model_info["model"], model_info["tokenizer"]
| # ================== Request Model ================== | |
class SummarizeRequest(BaseModel):
    """Request body accepted by the ``summarize`` handler.

    Fields:
        text: The text to be summarized.
        models: Model keys to run on ``text``; ``["model1"]`` when omitted.
    """

    # Source text; the handler rejects it when it is empty or whitespace-only.
    text: str
    # Keys are looked up in MODELS by the handler.
    models: list[str] = ["model1"]
| # ================== Summarization ================== | |
def summarize_text(text, model_key):
    """Summarize ``text`` with the model registered under ``model_key``.

    The input is tokenized with truncation at 1024 tokens, moved to the
    module-level ``device``, and decoded from the first sequence returned by
    ``generate`` (4 beams, at most 300 new tokens, early stopping).

    Args:
        text: Text to summarize.
        model_key: Key passed through to ``load_model``.

    Returns:
        The decoded summary with special tokens removed and surrounding
        whitespace stripped.
    """
    model, tokenizer = load_model(model_key)

    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
    encoded = encoded.to(device)

    generation_options = {
        "max_new_tokens": 300,
        "num_beams": 4,
        "early_stopping": True,
    }
    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    first_sequence = generated[0]
    decoded = tokenizer.decode(first_sequence, skip_special_tokens=True)
    return decoded.strip()
| # ================== Endpoints ================== | |
def root():
    """Return the static status payload saying the API is running.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is registered with ``app`` (presumably as ``GET /``).
    """
    status = {"message": "β Khmer Summarization API (FastAPI) Running!"}
    return status
def list_models():
    """Return the available model keys mapped to their display names.

    Only the ``"name"`` field of each ``MODELS`` entry is exposed; the cached
    model and tokenizer objects are left out of the response.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is registered with ``app``.
    """
    catalog = {}
    for model_key, entry in MODELS.items():
        catalog[model_key] = {"name": entry["name"]}
    return catalog
# Guards model loading and generation when they run in worker threads.
# While inference ran directly on the event loop, requests were implicitly
# serialized; this lock keeps that guarantee (no double-loading of a model,
# no overlapping generate calls) now that the work happens off the loop.
_INFERENCE_LOCK = threading.Lock()


def _summarize_serialized(text, model_key):
    """Run ``summarize_text`` while holding the inference lock."""
    with _INFERENCE_LOCK:
        return summarize_text(text, model_key)


async def summarize(req: SummarizeRequest):
    """Summarize ``req.text`` with each requested model and collect the results.

    Args:
        req: Request body carrying the text and the list of model keys.

    Returns:
        ``{"error": <message>}`` when the text is empty or whitespace-only;
        otherwise ``{"results": {key: {"name": ..., "summary": ...}}}`` with
        one entry per requested key that exists in ``MODELS``. Keys that are
        not in ``MODELS`` are skipped. A failure inside one model does not
        abort the request: that entry's ``"summary"`` becomes
        ``"Error: <exception text>"`` instead.

    NOTE(review): no route decorator is visible on this function in this
    view - confirm it is registered with ``app``.
    """
    if not req.text.strip():
        return {"error": "β οΈ ααΌαααΆααααα αΌαα’αααααααΆαα»α!"}

    results = {}
    # dict.fromkeys drops repeated keys while keeping request order, so a
    # model named twice is not run twice just to overwrite its own result.
    for key in dict.fromkeys(req.models):
        if key not in MODELS:
            continue
        try:
            # Model loading and generation are blocking and slow; running
            # them in a worker thread keeps the event loop free to serve
            # other requests in the meantime.
            summary = await asyncio.to_thread(_summarize_serialized, req.text, key)
        except Exception as e:
            # Per-model boundary: report the failure in that model's slot.
            summary = f"Error: {e}"
        results[key] = {
            "name": MODELS[key]["name"],
            "summary": summary,
        }
    return {"results": results}