sedtha commited on
Commit
c2110ef
Β·
verified Β·
1 Parent(s): 6092add

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -40
app.py CHANGED
@@ -1,70 +1,173 @@
1
- from fastapi import FastAPI
 
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
5
- import torch
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # --------------------
8
- # App initialization
9
- # --------------------
10
- app = FastAPI(title="Khmer Summarization API")
 
 
 
 
 
 
11
 
12
  app.add_middleware(
13
  CORSMiddleware,
14
- allow_origins=["*"], # allow frontend from anywhere
15
- allow_methods=["*"],
16
- allow_headers=["*"],
 
17
  )
18
 
19
- # --------------------
20
- # Model loading
21
- # --------------------
22
- MODEL_NAME = "sedtha/mBart-50-large_LoRa_kh_sumerize"
23
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- print("Loading tokenizer...")
26
- tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_NAME)
 
 
 
27
 
28
- print("Loading model...")
29
- model = MBartForConditionalGeneration.from_pretrained(MODEL_NAME)
30
- model.to(DEVICE)
31
- model.eval()
 
32
 
33
- print("Model loaded successfully!")
 
 
34
 
35
- # --------------------
36
- # Request schema
37
- # --------------------
 
 
38
  class SummarizeRequest(BaseModel):
39
  text: str
40
- max_length: int = 150
41
- min_length: int = 40
42
 
43
- # --------------------
44
- # API endpoint
45
- # --------------------
46
  @app.post("/summarize")
47
  def summarize(req: SummarizeRequest):
 
 
 
 
 
 
 
 
48
  inputs = tokenizer(
49
  req.text,
50
  return_tensors="pt",
51
  truncation=True,
52
  max_length=1024
53
- ).to(DEVICE)
54
 
55
  with torch.no_grad():
56
- output_ids = model.generate(
57
  **inputs,
58
- max_length=req.max_length,
59
- min_length=req.min_length,
60
- num_beams=4
 
 
 
 
61
  )
62
 
63
- summary = tokenizer.decode(
64
- output_ids[0],
65
- skip_special_tokens=True
66
- )
 
67
 
68
  return {
69
- "summary": summary
 
70
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import warnings
3
+ from fastapi import FastAPI, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
+ from peft import PeftModel
7
+ from transformers import (
8
+ MBartForConditionalGeneration, MBart50Tokenizer,
9
+ MT5ForConditionalGeneration, T5Tokenizer
10
+ )
11
+
12
+ warnings.filterwarnings("ignore")
13
+
14
+ app = FastAPI(
15
+ title="Khmer Summarization API",
16
+ description="mBART-LoRA + mT5 in ONE API",
17
+ version="1.0.0"
18
+ )
19
 
20
+ # ================= CORS Configuration =================
21
+ # Allow all origins for Hugging Face Spaces
22
+ origins = [
23
+ "https://*.hf.space", # Allow Hugging Face Spaces
24
+ "http://localhost",
25
+ "http://localhost:3000",
26
+ "http://127.0.0.1",
27
+ "http://127.0.0.1:3000",
28
+ "*" # You can be more restrictive in production
29
+ ]
30
 
31
  app.add_middleware(
32
  CORSMiddleware,
33
+ allow_origins=origins,
34
+ allow_credentials=True,
35
+ allow_methods=["*"], # Allows all methods (GET, POST, etc.)
36
+ allow_headers=["*"], # Allows all headers
37
  )
38
 
39
+ # ================= Device =================
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+
42
+ # ================= Models Config =================
43
+ MODELS = {
44
+ "model1": {
45
+ "name": "Khmer mBART + LoRA",
46
+ "type": "mbart",
47
+ "repo": "sedtha/mBart-50-large_LoRa_kh_sumerize",
48
+ "model": None,
49
+ "tokenizer": None
50
+ },
51
+ "model2": {
52
+ "name": "Khmer mT5",
53
+ "type": "mt5",
54
+ "repo": "angkor96/khmer-mT5-news-summarization",
55
+ "model": None,
56
+ "tokenizer": None
57
+ }
58
+ }
59
+
60
+ # ================= Load Model =================
61
+ def load_model(key: str):
62
+ info = MODELS[key]
63
+
64
+ if info["model"] is None:
65
+ print(f"πŸ”Ή Loading {info['name']}...")
66
+
67
+ if info["type"] == "mbart":
68
+ tokenizer = MBart50Tokenizer.from_pretrained(
69
+ info["repo"],
70
+ src_lang="km_KH",
71
+ tgt_lang="km_KH",
72
+ cache_dir="./cache"
73
+ )
74
+
75
+ base_model = MBartForConditionalGeneration.from_pretrained(
76
+ "facebook/mbart-large-50",
77
+ cache_dir="./cache"
78
+ ).to(device)
79
 
80
+ model = PeftModel.from_pretrained(
81
+ base_model,
82
+ info["repo"],
83
+ cache_dir="./cache"
84
+ ).to(device)
85
 
86
+ elif info["type"] == "mt5":
87
+ tokenizer = T5Tokenizer.from_pretrained(info["repo"], cache_dir="./cache")
88
+ model = MT5ForConditionalGeneration.from_pretrained(
89
+ info["repo"], cache_dir="./cache"
90
+ ).to(device)
91
 
92
+ model.eval()
93
+ info["model"] = model
94
+ info["tokenizer"] = tokenizer
95
 
96
+ print(f"βœ… Loaded {info['name']}")
97
+
98
+ return info["model"], info["tokenizer"]
99
+
100
+ # ================= Request Schema =================
101
  class SummarizeRequest(BaseModel):
102
  text: str
103
+ model: str = "model2"
 
104
 
105
+ # ================= API Endpoint =================
 
 
106
  @app.post("/summarize")
107
  def summarize(req: SummarizeRequest):
108
+ if not req.text.strip():
109
+ raise HTTPException(status_code=400, detail="Text is empty")
110
+
111
+ if req.model not in MODELS:
112
+ raise HTTPException(status_code=400, detail="Invalid model")
113
+
114
+ model, tokenizer = load_model(req.model)
115
+
116
  inputs = tokenizer(
117
  req.text,
118
  return_tensors="pt",
119
  truncation=True,
120
  max_length=1024
121
+ ).to(device)
122
 
123
  with torch.no_grad():
124
+ summary_ids = model.generate(
125
  **inputs,
126
+ do_sample=True,
127
+ temperature=0.8,
128
+ top_p=0.9,
129
+ top_k=50,
130
+ max_new_tokens=125,
131
+ repetition_penalty=1.2,
132
+ no_repeat_ngram_size=3
133
  )
134
 
135
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
136
+
137
+ # Khmer sentence cleanup
138
+ if "αŸ”" in summary:
139
+ summary = summary[:summary.rfind("αŸ”") + 1]
140
 
141
  return {
142
+ "model": MODELS[req.model]["name"],
143
+ "summary": summary.strip()
144
  }
145
+
146
+ # ================= Health Check =================
147
+ @app.get("/")
148
+ def root():
149
+ return {"status": "Khmer Summarization API is running πŸš€"}
150
+
151
+ # ================= Additional endpoint for testing =================
152
+ @app.get("/health")
153
+ def health_check():
154
+ return {
155
+ "status": "healthy",
156
+ "device": str(device),
157
+ "models_loaded": {
158
+ key: info["model"] is not None
159
+ for key, info in MODELS.items()
160
+ }
161
+ }
162
+
163
+ # ================= Pre-load models on startup (optional) =================
164
+ @app.on_event("startup")
165
+ async def startup_event():
166
+ # Optionally pre-load both models on startup
167
+ # This will make first request faster but uses more memory
168
+ print("πŸš€ Starting up...")
169
+ print(f"Using device: {device}")
170
+
171
+ # You can choose to pre-load models or load them on first request
172
+ # For memory efficiency, we'll load on first request
173
+ print("Models will be loaded on first request to save memory")