FaiziRBLX committed on
Commit
3d4765d
·
verified ·
1 Parent(s): 7041c0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -130
app.py CHANGED
@@ -5,32 +5,43 @@ import os
5
  import logging
6
  from collections import defaultdict
7
  from transformers import AutoTokenizer
8
- from fastapi import Request, HTTPException, Depends
9
- from fastapi.middleware.cors import CORSMiddleware
10
- from slowapi import Limiter, _rate_limit_exceeded_handler
11
- from slowapi.util import get_remote_address
12
- from slowapi.errors import RateLimitExceeded
13
- from pydantic import BaseModel, Field
14
- from fastapi import FastAPI
15
  from fastapi.responses import JSONResponse
 
16
  import gradio as gr
17
  from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
18
 
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
- # ── Load Model ───────────────────────────────────────────
23
- device = torch.device('cpu')
 
24
 
 
25
  logger.info(f"model.pt ada: {os.path.exists('model.pt')}")
 
 
 
26
 
 
 
27
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
28
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
 
29
 
 
 
30
  checkpoint = torch.load("model.pt", map_location='cpu', weights_only=False)
31
- config = checkpoint['config']
32
- model = IndonesianLLM(config)
 
 
 
 
33
 
 
34
  state_dict = checkpoint['model_state_dict']
35
  for k in list(state_dict.keys()):
36
  if state_dict[k].dtype == torch.float16:
@@ -39,172 +50,204 @@ for k in list(state_dict.keys()):
39
  model.load_state_dict(state_dict)
40
  del checkpoint, state_dict
41
  gc.collect()
 
42
  model.eval()
 
43
  logger.info("Model siap!")
44
 
45
- # ── Rate limiter ─────────────────────────────────────────
46
- limiter = Limiter(key_func=get_remote_address)
47
- ip_request_count: dict = defaultdict(list)
48
- ip_banned_until: dict = {}
49
-
50
- API_KEYS = {"kunci-rahasia-kamu-123"} # ← ganti!
51
-
52
- # ── Gradio UI ────────────────────────────────────────────
53
- def gradio_chat(message, history):
54
- prompt = f"{message} <cot>"
55
- full = generate_text(
56
- model=model, tokenizer=tokenizer, prompt=prompt,
57
- max_new_tokens=200, temperature=0.7,
58
- top_k=50, top_p=0.9, device=device
59
- )
60
- raw = full[len(prompt):].strip()
61
- _, answer = _extract_thinking(raw)
62
- return answer if answer else "Maaf, saya tidak mengerti."
63
-
64
- demo = gr.ChatInterface(
65
- fn=gradio_chat,
66
- title="Indonesian LLM",
67
- description="Chat dengan model bahasa Indonesia"
68
  )
69
 
70
- # ── Tambah API route ke Gradio's FastAPI ─────────────────
71
- app = demo.app # Gradio expose FastAPI internal di sini
72
-
73
- app.state.limiter = limiter
74
- app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
75
  app.add_middleware(
76
  CORSMiddleware,
77
  allow_origins=["*"],
78
- allow_methods=["POST", "GET"],
79
  allow_headers=["*"],
80
  )
81
 
 
82
  @app.middleware("http")
83
  async def ddos_protection(request: Request, call_next):
84
- ip = get_remote_address(request)
85
  now = time.time()
 
86
  if ip in ip_banned_until:
87
  if now < ip_banned_until[ip]:
88
- raise HTTPException(429, f"Banned. Coba lagi dalam {int(ip_banned_until[ip]-now)}s.")
89
- del ip_banned_until[ip]
90
- ip_request_count[ip] = []
 
 
 
 
 
 
91
  ip_request_count[ip].append(now)
92
- ip_request_count[ip] = [t for t in ip_request_count[ip] if now - t < 60]
93
- if len(ip_request_count[ip]) > 100:
94
- ip_banned_until[ip] = now + 3600
95
- raise HTTPException(429, "Terlalu banyak request. Banned 1 jam.")
96
- return await call_next(request)
97
 
98
- class ChatRequest(BaseModel):
99
- message: str = Field(..., min_length=1, max_length=500)
100
- max_tokens: int = Field(default=200, ge=10, le=500)
101
- temperature: float = Field(default=0.7, ge=0.1, le=1.5)
102
- show_thinking: bool = False
 
 
103
 
104
- class ChatResponse(BaseModel):
105
- answer: str
106
- thinking: str | None = None
107
- processing_time_ms: int
108
 
109
- def verify_api_key(request: Request):
 
 
 
 
110
  key = request.headers.get("X-API-Key")
111
  if not key or key not in API_KEYS:
112
- raise HTTPException(401, "API key tidak valid.")
113
- return key
114
 
115
  @app.get("/api/health")
116
  def health():
117
- return {"status": "ok", "device": str(device)}
118
-
119
- @app.post("/api/chat", response_model=ChatResponse)
120
- @limiter.limit("10/minute")
121
- @limiter.limit("50/hour")
122
- async def api_chat(
123
- req: ChatRequest,
124
- request: Request,
125
- _key: str = Depends(verify_api_key)
126
- ):
127
- start = time.time()
128
- prompt = f"{req.message} <cot>"
129
- full = generate_text(
130
- model=model, tokenizer=tokenizer, prompt=prompt,
131
- max_new_tokens=req.max_tokens, temperature=req.temperature,
132
- top_k=50, top_p=0.9, device=device
133
- )
134
- raw = full[len(prompt):].strip()
135
- thinking, answer = _extract_thinking(raw)
136
- return ChatResponse(
137
- answer=answer if answer else "Maaf, saya tidak mengerti.",
138
- thinking=thinking if req.show_thinking else None,
139
- processing_time_ms=int((time.time() - start) * 1000)
140
- )
141
-
142
- # Ganti bagian bawah app.py β€” dari "Tambah API route" sampai akhir
143
-
144
- # ── Build Gradio dulu ─────────────────────────────────────
145
- def gradio_chat(message, history):
146
- prompt = f"{message} <cot>"
147
- full = generate_text(
148
- model=model, tokenizer=tokenizer, prompt=prompt,
149
- max_new_tokens=200, temperature=0.7,
150
- top_k=50, top_p=0.9, device=device
151
- )
152
- raw = full[len(prompt):].strip()
153
- _, answer = _extract_thinking(raw)
154
- return answer if answer else "Maaf, saya tidak mengerti."
155
-
156
- demo = gr.ChatInterface(
157
- fn=gradio_chat,
158
- title="Indonesian LLM",
159
- description="Chat dengan model bahasa Indonesia"
160
- )
161
 
162
- # Tambah route langsung ke demo.app
163
- @demo.app.get("/api/health")
164
- def health():
165
- return {"status": "ok", "device": str(device)}
166
-
167
- @demo.app.post("/api/chat")
168
  async def api_chat(request: Request):
169
  # Cek API key
170
- key = request.headers.get("X-API-Key")
171
- if not key or key not in API_KEYS:
172
- return JSONResponse(status_code=401, content={"error": "API key tidak valid."})
 
 
173
 
174
- # Parse body manual (hindari Pydantic issue)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  try:
176
  body = await request.json()
177
- message = body.get("message", "").strip()
178
  max_tokens = int(body.get("max_tokens", 200))
179
  temperature = float(body.get("temperature", 0.7))
180
  show_think = bool(body.get("show_thinking", False))
181
  except Exception:
182
- return JSONResponse(status_code=400, content={"error": "Request tidak valid."})
 
 
 
183
 
184
- if not message or len(message) > 500:
185
- return JSONResponse(status_code=400, content={"error": "Pesan kosong atau terlalu panjang."})
 
 
 
 
 
 
 
186
 
187
  # Generate
188
  try:
189
- start = time.time()
190
- prompt = f"{message} <cot>"
191
- full = generate_text(
192
- model=model, tokenizer=tokenizer, prompt=prompt,
193
- max_new_tokens=max_tokens, temperature=temperature,
194
- top_k=50, top_p=0.9, device=device
 
 
 
 
 
195
  )
196
  raw = full[len(prompt):].strip()
197
  thinking, answer = _extract_thinking(raw)
198
- elapsed = int((time.time() - start) * 1000)
 
 
199
 
200
  return JSONResponse(content={
201
  "answer": answer if answer else "Maaf, saya tidak mengerti.",
202
  "thinking": thinking if show_think else None,
203
- "processing_time_ms": elapsed
204
  })
 
205
  except Exception as e:
206
  logger.error(f"Generate error: {e}")
207
- return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- # ── Launch ───────────────────────────────────────────────
210
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
 
 
5
  import logging
6
  from collections import defaultdict
7
  from transformers import AutoTokenizer
8
+ from fastapi import FastAPI, Request
 
 
 
 
 
 
9
  from fastapi.responses import JSONResponse
10
+ from fastapi.middleware.cors import CORSMiddleware
11
  import gradio as gr
12
  from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
13
 
14
+ # ── Logging ───────────────────────────────────────────────
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ # ── Device ────────────────────────────────────────────────
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ logger.info(f"Device: {device}")
21
 
22
+ # ── Cek model file ────────────────────────────────────────
23
  logger.info(f"model.pt ada: {os.path.exists('model.pt')}")
24
+ if not os.path.exists('model.pt'):
25
+ raise FileNotFoundError("model.pt tidak ditemukan! Upload dulu ke Space.")
26
+ logger.info(f"model.pt size: {os.path.getsize('model.pt') / 1e6:.1f} MB")
27
 
28
+ # ── Load tokenizer ────────────────────────────────────────
29
+ logger.info("Loading tokenizer...")
30
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
31
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
32
+ logger.info("Tokenizer OK")
33
 
34
+ # ── Load model ────────────────────────────────────────────
35
+ logger.info("Loading checkpoint...")
36
  checkpoint = torch.load("model.pt", map_location='cpu', weights_only=False)
37
+ logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")
38
+
39
+ logger.info("Building model...")
40
+ config = checkpoint['config']
41
+ model = IndonesianLLM(config)
42
+ logger.info(f"Model params: {model.count_parameters():,}")
43
 
44
+ logger.info("Loading weights...")
45
  state_dict = checkpoint['model_state_dict']
46
  for k in list(state_dict.keys()):
47
  if state_dict[k].dtype == torch.float16:
 
50
  model.load_state_dict(state_dict)
51
  del checkpoint, state_dict
52
  gc.collect()
53
+
54
  model.eval()
55
+ model.to(device)
56
  logger.info("Model siap!")
57
 
58
# ── Config ────────────────────────────────────────────────
# API keys accepted in the X-API-Key header.  Deployments should supply them
# through the LLM_API_KEYS environment variable (comma-separated) instead of
# committing a secret to the repository.  The placeholder below is used only
# when that variable is unset or empty, and must be replaced.
_PLACEHOLDER_API_KEY = "kunci-rahasia-kamu-123"
API_KEYS = {
    candidate.strip()
    for candidate in os.environ.get("LLM_API_KEYS", "").split(",")
    if candidate.strip()
} or {_PLACEHOLDER_API_KEY}

# Request timestamps (time.time() floats) per client key.
ip_request_count = defaultdict(list)
# Client IP -> epoch time at which its ban expires.
ip_banned_until = {}

BLACKLIST_THRESHOLD = 100   # requests tolerated per window before a ban
BLACKLIST_WINDOW = 60       # sliding-window length, in seconds
BLACKLIST_DURATION = 3600   # ban length, in seconds
65
+
66
+ # ═══════════════════════════════════════════════════════════
67
+ # 1. FastAPI (induk)
68
+ # ═══════════════════════════════════════════════════════════
69
# Title, description and version handed to the FastAPI constructor.
_API_INFO = dict(
    title="Indonesian LLM API",
    description="API untuk model bahasa Indonesia dengan Chain-of-Thought",
    version="1.0.0",
)
app = FastAPI(**_API_INFO)

# ── CORS ──────────────────────────────────────────────────
# Every origin, method and header is allowed.
_WILDCARD = "*"
app.add_middleware(
    CORSMiddleware,
    allow_origins=[_WILDCARD],
    allow_methods=[_WILDCARD],
    allow_headers=[_WILDCARD],
)
82
 
83
# ── DDoS protection ───────────────────────────────────────
_SWEEP_INTERVAL = 300  # seconds between clean-ups of stale tracking entries
_last_sweep_at = 0.0


def _sweep_stale_tracking(now):
    """Forget expired bans and request logs that hold no recent timestamp.

    Without this, both tracking dicts keep one entry per client ever seen,
    so memory grows without bound when many distinct addresses connect.
    """
    expired = [ip for ip, until in ip_banned_until.items() if until <= now]
    for ip in expired:
        del ip_banned_until[ip]

    # The chat endpoint stores its own "<ip>_chat" logs in the same dict and
    # filters them with a 60-second window, so never drop anything younger.
    horizon = max(BLACKLIST_WINDOW, 60)
    idle = [
        key
        for key, stamps in ip_request_count.items()
        # Timestamps are appended in order, so the last one is the newest.
        if not stamps or now - stamps[-1] >= horizon
    ]
    for key in idle:
        del ip_request_count[key]


@app.middleware("http")
async def ddos_protection(request: Request, call_next):
    """Ban clients that exceed BLACKLIST_THRESHOLD requests per window.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        A 429 JSON response while the client is banned (or at the moment it
        gets banned); otherwise the downstream response.
    """
    global _last_sweep_at

    # NOTE(review): behind a reverse proxy request.client.host is presumably
    # the proxy address, which would put every visitor in one bucket.
    # Confirm how the deployment forwards client addresses.
    ip = request.client.host if request.client else "unknown"
    now = time.time()

    if now - _last_sweep_at >= _SWEEP_INTERVAL:
        _sweep_stale_tracking(now)
        _last_sweep_at = now

    banned_until = ip_banned_until.get(ip)
    if banned_until is not None:
        if now < banned_until:
            remaining = int(banned_until - now)
            return JSONResponse(
                status_code=429,
                content={"error": f"IP dibanned. Coba lagi dalam {remaining} detik."},
                headers={"Retry-After": str(max(remaining, 1))},
            )
        # Ban has expired: start over with a clean request log.
        del ip_banned_until[ip]
        ip_request_count[ip] = []

    ip_request_count[ip].append(now)
    ip_request_count[ip] = [
        t for t in ip_request_count[ip] if now - t < BLACKLIST_WINDOW
    ]

    if len(ip_request_count[ip]) > BLACKLIST_THRESHOLD:
        ip_banned_until[ip] = now + BLACKLIST_DURATION
        ip_request_count[ip] = []
        return JSONResponse(
            status_code=429,
            content={"error": f"Terlalu banyak request. IP dibanned selama {BLACKLIST_DURATION // 60} menit."},
            headers={"Retry-After": str(BLACKLIST_DURATION)},
        )

    return await call_next(request)
 
 
 
112
 
113
+ # ═══════════════════════════════════════════════════════════
114
+ # 2. API Routes
115
+ # ═══════════════════════════════════════════════════════════
116
+
117
def check_api_key(request: "Request"):
    """Return True when the request carries a known key in ``X-API-Key``.

    Args:
        request: Incoming request; only ``request.headers`` is read.

    Returns:
        False when the header is missing, empty, or matches no entry of
        ``API_KEYS``; True otherwise.
    """
    import hmac

    key = request.headers.get("X-API-Key")
    if not key:
        return False

    supplied = key.encode("utf-8")
    matched = False
    # Constant-time comparison against every key, without an early exit, so
    # response timing does not reveal how much of a guess was correct.
    for valid in API_KEYS:
        if hmac.compare_digest(supplied, valid.encode("utf-8")):
            matched = True
    return matched
122
 
123
@app.get("/api/health")
def health():
    """Return a status marker, the compute device and the model size."""
    payload = {"status": "ok"}
    payload["device"] = str(device)
    payload["model_params"] = model.count_parameters()
    return payload
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
def _parse_chat_payload(body):
    """Validate the decoded JSON body of a chat request.

    Returns:
        ``(params, None)`` on success, where ``params`` is a dict with the
        keys ``message``, ``max_tokens``, ``temperature`` and
        ``show_thinking``; or ``(None, error_message)`` when the body must
        be rejected with HTTP 400.
    """
    malformed = "Request body tidak valid. Gunakan JSON."
    if not isinstance(body, dict):
        return None, malformed

    raw_message = body.get("message", "")
    if raw_message is None:
        # A JSON null used to be stringified into the prompt "None".
        raw_message = ""
    if not isinstance(raw_message, str):
        return None, malformed
    message = raw_message.strip()

    try:
        max_tokens = int(body.get("max_tokens", 200))
        temperature = float(body.get("temperature", 0.7))
    except (TypeError, ValueError, OverflowError):
        return None, malformed
    show_think = bool(body.get("show_thinking", False))

    if not message:
        return None, "Pesan tidak boleh kosong."
    if len(message) > 500:
        return None, "Pesan terlalu panjang. Maksimal 500 karakter."
    if not (10 <= max_tokens <= 500):
        return None, "max_tokens harus antara 10 dan 500."
    if not (0.1 <= temperature <= 1.5):
        return None, "temperature harus antara 0.1 dan 1.5."

    return {
        "message": message,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "show_thinking": show_think,
    }, None


@app.post("/api/chat")
async def api_chat(request: Request):
    """Generate a reply for one chat message.

    Expects a JSON object with ``message`` (required) and the optional
    ``max_tokens``, ``temperature`` and ``show_thinking`` fields.

    Returns:
        200 with ``answer``, ``thinking`` and ``processing_time_ms``;
        401 for a missing or unknown API key; 429 when the client sent 10
        requests within the last 60 seconds; 400 for an invalid body;
        500 when generation fails.
    """
    import asyncio
    import functools

    # Check API key
    if not check_api_key(request):
        return JSONResponse(
            status_code=401,
            content={"error": "API key tidak valid. Tambahkan header X-API-Key."}
        )

    # Per-endpoint rate limit (10 requests per minute per IP)
    ip = request.client.host if request.client else "unknown"
    now = time.time()
    endpoint_key = f"{ip}_chat"
    recent = [t for t in ip_request_count.get(endpoint_key, []) if now - t < 60]
    ip_request_count[endpoint_key] = recent
    if len(recent) >= 10:
        # The oldest timestamp is the first one to leave the window.
        retry_after = max(1, int(60 - (now - recent[0])) + 1)
        return JSONResponse(
            status_code=429,
            content={"error": "Rate limit: maksimal 10 request per menit."},
            headers={"Retry-After": str(retry_after)},
        )
    recent.append(now)

    # Parse and validate the request body
    try:
        body = await request.json()
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        return JSONResponse(
            status_code=400,
            content={"error": "Request body tidak valid. Gunakan JSON."}
        )
    params, problem = _parse_chat_payload(body)
    if problem is not None:
        return JSONResponse(status_code=400, content={"error": problem})
    message = params["message"]

    # Generate
    start = time.time()
    prompt = f"{message} <cot>"
    generate = functools.partial(
        generate_text,
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_new_tokens=params["max_tokens"],
        temperature=params["temperature"],
        top_k=50,
        top_p=0.9,
        device=device,
    )
    try:
        # generate_text is a blocking call; running it on the event loop
        # would freeze every other request until generation finishes.
        full = await asyncio.get_running_loop().run_in_executor(None, generate)
        raw = full[len(prompt):].strip()
        thinking, answer = _extract_thinking(raw)
    except Exception:
        # Keep the traceback in the server log; do not leak internals.
        logger.exception("Generate error")
        return JSONResponse(
            status_code=500,
            content={"error": "Gagal generate. Silakan coba lagi."}
        )

    elapsed_ms = int((time.time() - start) * 1000)
    logger.info("[%s] '%s' → %dms", ip, message[:40], elapsed_ms)

    return JSONResponse(content={
        "answer": answer if answer else "Maaf, saya tidak mengerti.",
        "thinking": thinking if params["show_thinking"] else None,
        "processing_time_ms": elapsed_ms
    })
211
+
212
+ # ═══════════════════════════════════════════════════════════
213
+ # 3. Gradio UI
214
+ # ═══════════════════════════════════════════════════════════
215
def gradio_chat(message, history):
    """Answer one chat turn for the Gradio interface.

    Args:
        message: Text typed by the user.
        history: Previous turns supplied by Gradio; not used.

    Returns:
        The extracted answer, a fallback sentence when the answer is empty,
        a prompt to type something when ``message`` is blank or not a
        string, or a generic apology when generation fails.
    """
    # Guard before touching the model: a missing or non-text message used to
    # raise here, outside the error handling below.
    if not isinstance(message, str) or not message.strip():
        return "Silakan ketik pesan."

    prompt = f"{message} <cot>"
    try:
        full = generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=200,
            temperature=0.7,
            top_k=50,
            top_p=0.9,
            device=device
        )
        # Drop the echoed prompt, then keep only the answer part.
        raw = full[len(prompt):].strip()
        _, answer = _extract_thinking(raw)
    except Exception:
        # Log the traceback server-side; do not show internals to visitors.
        logger.exception("Gradio error")
        return "Maaf, terjadi kesalahan saat memproses pesan."

    return answer if answer else "Maaf, saya tidak mengerti."
236
+
237
# Sample prompts offered in the chat UI.
_EXAMPLE_PROMPTS = [
    ["Halo, apa kabar?"],
    ["Jelaskan cara kerja internet"],
    ["Berapa hasil dari 25 dikali 4?"],
    ["Apa ibu kota Indonesia?"],
]
_UI_DESCRIPTION = (
    "Model bahasa Indonesia dengan kemampuan Chain-of-Thought reasoning. "
    "Juga tersedia sebagai API di `/api/chat`."
)

# Chat front end backed by gradio_chat.
gradio_ui = gr.ChatInterface(
    fn=gradio_chat,
    title="Indonesian LLM",
    description=_UI_DESCRIPTION,
    examples=_EXAMPLE_PROMPTS,
    theme=gr.themes.Soft(),
)
249
 
250
+ # ═══════════════════════════════════════════════════════════
251
+ # 4. Mount Gradio ke FastAPI β€” FastAPI sebagai induk
252
+ # ═══════════════════════════════════════════════════════════
253
# Attach the Gradio UI to the FastAPI app at the site root.
# NOTE(review): nothing visible in this change starts an ASGI server after
# the earlier demo.launch(...) call was removed; confirm how the app is
# served.
demo = gr.mount_gradio_app(app=app, blocks=gradio_ui, path="/")