# File size: 3,802 Bytes
# e76e570
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
    MBartForConditionalGeneration, MBart50Tokenizer,
    MT5ForConditionalGeneration, T5Tokenizer
)
from peft import PeftModel
import warnings
from dotenv import load_dotenv

# ================== Config ==================
# Pull settings from a local .env file into the process environment.
load_dotenv()
warnings.filterwarnings("ignore", category=FutureWarning)

app = FastAPI(title="Khmer Summarization API")

# Allow CORS for the JS frontend.
# Credentials stay disabled: a wildcard origin combined with
# allow_credentials=True lets any website make credentialed cross-site calls,
# and the FastAPI CORS documentation requires explicit origins in that case.
# None of the endpoints visible in this file read cookies or auth headers.
# If credentials are ever needed, list the exact frontend origins instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== Models ==================
# Registry of the available summarizers, keyed by the id clients send.
# Every entry starts with "model" and "tokenizer" set to None; load_model()
# fills them in the first time that entry is requested.
MODELS = {
    key: {
        "name": display_name,
        "repo": repo_id,
        "type": kind,
        "model": None,
        "tokenizer": None,
    }
    for key, display_name, repo_id, kind in (
        (
            "model1",
            "Model 1 - Khmer MBart Summarization (LoRA)",
            "sedtha/mBart-50-large_LoRa_kh_sumerize",
            "mbart_lora",
        ),
        (
            "model2",
            "Model 2 - Khmer mT5 Summarization",
            "angkor96/khmer-mT5-news-summarization",
            "mt5",
        ),
    )
}

# ================== Load Model ==================
def load_model(model_key):
    """Return the ``(model, tokenizer)`` pair for *model_key*, loading on first use.

    The loaded objects are cached in the matching ``MODELS`` entry, so later
    calls for the same key return the cached pair without reloading.

    Args:
        model_key: Key of an entry in ``MODELS``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is on ``device`` and in
        eval mode.

    Raises:
        KeyError: If *model_key* is not a key of ``MODELS``.
        ValueError: If the entry's ``"type"`` is not a supported model type.
    """
    model_info = MODELS[model_key]

    # Already loaded by an earlier call: reuse the cached objects.
    if model_info["model"] is not None:
        return model_info["model"], model_info["tokenizer"]

    model_type = model_info["type"]
    # Reject a bad registry entry before announcing a load that cannot happen.
    if model_type not in ("mbart_lora", "mt5"):
        raise ValueError(f"Unknown model type: {model_type!r}")

    repo = model_info["repo"]
    print(f"🔄 Loading {model_info['name']} ...")

    if model_type == "mbart_lora":
        tokenizer = MBart50Tokenizer.from_pretrained(
            repo, src_lang="km_KH", tgt_lang="km_KH"
        )
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50"
        ).to(device)
        # Apply the LoRA adapter from the repo on top of the base checkpoint,
        # then merge it into the base weights and drop the PEFT wrapper.
        model = PeftModel.from_pretrained(base, repo).merge_and_unload().to(device)
    else:  # model_type == "mt5"
        tokenizer = T5Tokenizer.from_pretrained(repo)
        model = MT5ForConditionalGeneration.from_pretrained(repo).to(device)

    model.eval()

    # Store the tokenizer first: the cache check above keys on "model", so the
    # entry must never look loaded while its tokenizer is still missing.
    model_info["tokenizer"] = tokenizer
    model_info["model"] = model

    print(f"✅ Loaded {model_info['name']}")

    return model, tokenizer


# ================== Request Model ==================
class SummarizeRequest(BaseModel):
    """Request body accepted by the ``/summarize`` endpoint."""

    # Raw text to summarize; the endpoint returns an error when it is blank
    # after stripping whitespace.
    text: str
    # Keys into MODELS selecting which summarizers to run; only "model1" by default.
    models: list[str] = ["model1"]


# ================== Summarization ==================
def summarize_text(text, model_key):
    """Generate a summary of *text* with the model registered under *model_key*.

    The input is tokenized with truncation at 1024 tokens, decoded with
    4-beam search for at most 300 new tokens, and the first returned sequence
    is converted back to text with special tokens removed.

    Args:
        text: The text to summarize.
        model_key: Key identifying which model to use; passed to ``load_model``.

    Returns:
        The decoded summary with surrounding whitespace stripped.
    """
    model, tokenizer = load_model(model_key)

    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
    )
    encoded = encoded.to(device)

    generation_options = {
        "max_new_tokens": 300,
        "num_beams": 4,
        "early_stopping": True,
    }

    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated = model.generate(**encoded, **generation_options)

    first_sequence = generated[0]
    summary = tokenizer.decode(first_sequence, skip_special_tokens=True)
    return summary.strip()


# ================== Endpoints ==================
@app.get("/")
def root():
    """Liveness endpoint: return a fixed message showing the API is running."""
    return {"message": "✅ Khmer Summarization API (FastAPI) Running!"}

@app.get("/models")
def list_models():
    """Return every registered model key mapped to ``{"name": <display name>}``.

    Only the display name is exposed; the other fields of each ``MODELS``
    entry are left out of the response.
    """
    catalog = {}
    for model_key, info in MODELS.items():
        catalog[model_key] = {"name": info["name"]}
    return catalog

@app.post("/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize ``req.text`` with each requested model that exists in ``MODELS``.

    Args:
        req: Request body holding the text and the list of model keys to run.

    Returns:
        ``{"error": <message>}`` when the text is empty or whitespace-only.
        Otherwise ``{"results": {key: {"name": ..., "summary": ...}}}`` with
        one entry per recognised model key. A model that raises reports
        ``"Error: <exception text>"`` as its summary, so one failure does not
        discard the other models' output. Unrecognised keys are skipped.
    """
    if not req.text.strip():
        # Khmer: "Please enter the text first!"
        return {"error": "⚠️ សូមវាយបញ្ចូលអត្ថបទជាមុន!"}

    # NOTE(review): summarize_text is synchronous and is called directly from
    # this coroutine, so the event loop is busy for the whole generation.

    results = {}
    # dict.fromkeys drops repeated keys while keeping request order. Results
    # are keyed by model key, so running a duplicate again would only repeat
    # the generation and overwrite the same entry.
    for key in dict.fromkeys(req.models):
        if key not in MODELS:
            continue
        try:
            summary = summarize_text(req.text, key)
        except Exception as e:
            # Endpoint boundary: report the failure for this model only.
            summary = f"Error: {e}"
        results[key] = {
            "name": MODELS[key]["name"],
            "summary": summary,
        }

    return {"results": results}