FaiziRBLX commited on
Commit
166c4d3
Β·
verified Β·
1 Parent(s): 4bc037c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -28
app.py CHANGED
@@ -1,22 +1,27 @@
1
  import torch
2
- import gradio as gr
 
 
3
  from transformers import AutoTokenizer
4
- from best import ModelConfig, IndonesianLLM
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Load tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
8
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
9
 
10
- # Load checkpoint (strukturnya: {"model_state_dict": ..., "config": ..., dst})
11
- checkpoint = torch.load("model.pt", map_location=torch.device('cpu'), weights_only=False)
12
-
13
- # Ambil config dari checkpoint (bukan ModelConfig default!)
14
  config = checkpoint['config']
15
-
16
- # Bangun kerangka model sesuai config yang tersimpan
17
  model = IndonesianLLM(config)
18
 
19
- # Ambil bobot, konversi fp16 β†’ fp32 jika perlu
20
  state_dict = checkpoint['model_state_dict']
21
  if checkpoint.get('dtype') == 'fp16':
22
  state_dict = {k: v.float() if v.dtype == torch.float16 else v
@@ -24,29 +29,128 @@ if checkpoint.get('dtype') == 'fp16':
24
 
25
  model.load_state_dict(state_dict)
26
  model.eval()
27
-
28
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
  model.to(device)
30
 
31
- # Fungsi inference
32
- def predict(teks_input):
33
- from best import generate_text, _extract_thinking
34
- prompt = f"{teks_input} <cot>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  full = generate_text(
36
  model=model, tokenizer=tokenizer, prompt=prompt,
37
- max_new_tokens=200, temperature=0.7,
38
  top_k=50, top_p=0.9, device=device
39
  )
40
  raw = full[len(prompt):].strip()
41
- _, answer = _extract_thinking(raw)
42
- return answer if answer else "Maaf, saya tidak mengerti."
43
-
44
- # Gradio UI
45
- iface = gr.Interface(
46
- fn=predict,
47
- inputs=gr.Textbox(lines=2, placeholder="Ketik pesan di sini..."),
48
- outputs="text",
49
- title="Indonesian LLM API"
50
- )
51
 
52
- iface.launch()
 
 
 
 
 
1
  import torch
2
+ import time
3
+ import hashlib
4
+ from collections import defaultdict
5
  from transformers import AutoTokenizer
6
+ from fastapi import FastAPI, Request, HTTPException, Depends
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.middleware.trustedhost import TrustedHostMiddleware
9
+ from slowapi import Limiter, _rate_limit_exceeded_handler
10
+ from slowapi.util import get_remote_address
11
+ from slowapi.errors import RateLimitExceeded
12
+ from pydantic import BaseModel, Field
13
+ from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
14
+
15
+ # ── Load model ──────────────────────────────────────────────────────────────
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
 
18
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
19
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
20
 
21
+ checkpoint = torch.load("model.pt", map_location=device, weights_only=False)
 
 
 
22
  config = checkpoint['config']
 
 
23
  model = IndonesianLLM(config)
24
 
 
25
  state_dict = checkpoint['model_state_dict']
26
  if checkpoint.get('dtype') == 'fp16':
27
  state_dict = {k: v.float() if v.dtype == torch.float16 else v
 
29
 
30
  model.load_state_dict(state_dict)
31
  model.eval()
 
 
32
  model.to(device)
33
 
34
+ # ── Rate Limiter (slowapi) ───────────────────────────────────────────────────
35
+ limiter = Limiter(key_func=get_remote_address)
36
+
37
+ # ── IP Blacklist (in-memory, reset saat restart) ────────────────────────────
38
+ ip_blacklist: set = set()
39
+ ip_request_count: dict = defaultdict(list) # ip -> [timestamp, ...]
40
+
41
+ BLACKLIST_THRESHOLD = 100 # request dalam window ini β†’ blacklist
42
+ BLACKLIST_WINDOW = 60 # detik
43
+ BLACKLIST_DURATION = 3600 # banned 1 jam (simpan di set terpisah)
44
+ ip_banned_until: dict = {} # ip -> timestamp banned sampai kapan
45
+
46
+ # ── FastAPI setup ───────────────────────────────────────────────────────────
47
+ app = FastAPI(title="Indonesian LLM API")
48
+ app.state.limiter = limiter
49
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
50
+
51
+ # CORS β€” ganti origins sesuai domain kamu
52
+ app.add_middleware(
53
+ CORSMiddleware,
54
+ allow_origins=["https://nousai.netlify.app"], # ganti! jangan "*" di production
55
+ allow_methods=["POST", "GET"],
56
+ allow_headers=["*"],
57
+ )
58
+
59
+ # Trusted hosts β€” tolak request dengan Host header aneh
60
+ app.add_middleware(
61
+ TrustedHostMiddleware,
62
+ allowed_hosts=["yourdomain.com", "localhost", "127.0.0.1"]
63
+ )
64
+
65
+ # ── Middleware: DDoS / Flood Detection ──────────────────────────────────────
66
+ @app.middleware("http")
67
+ async def ddos_protection(request: Request, call_next):
68
+ ip = get_remote_address(request)
69
+ now = time.time()
70
+
71
+ # Cek apakah IP sedang dibanned
72
+ if ip in ip_banned_until:
73
+ if now < ip_banned_until[ip]:
74
+ remaining = int(ip_banned_until[ip] - now)
75
+ return HTTPException(
76
+ status_code=429,
77
+ detail=f"IP banned. Coba lagi dalam {remaining} detik."
78
+ )
79
+ else:
80
+ # Ban sudah habis
81
+ del ip_banned_until[ip]
82
+ ip_request_count[ip] = []
83
+
84
+ # Catat timestamp request ini
85
+ ip_request_count[ip].append(now)
86
+
87
+ # Bersihkan request yang sudah di luar window
88
+ ip_request_count[ip] = [
89
+ t for t in ip_request_count[ip]
90
+ if now - t < BLACKLIST_WINDOW
91
+ ]
92
+
93
+ # Jika terlalu banyak request β†’ ban
94
+ if len(ip_request_count[ip]) > BLACKLIST_THRESHOLD:
95
+ ip_banned_until[ip] = now + BLACKLIST_DURATION
96
+ ip_request_count[ip] = []
97
+ raise HTTPException(
98
+ status_code=429,
99
+ detail=f"Terlalu banyak request. IP dibanned selama {BLACKLIST_DURATION//60} menit."
100
+ )
101
+
102
+ response = await call_next(request)
103
+ return response
104
+
105
+ # ── Request/Response Schema ─────────────────────────────────────────────────
106
+ class ChatRequest(BaseModel):
107
+ message: str = Field(..., min_length=1, max_length=500) # batasi panjang input
108
+ max_tokens: int = Field(default=200, ge=10, le=500) # min 10, max 500
109
+ temperature: float = Field(default=0.7, ge=0.1, le=1.5)
110
+ show_thinking: bool = False
111
+
112
+ class ChatResponse(BaseModel):
113
+ answer: str
114
+ thinking: str | None = None
115
+ processing_time_ms: int
116
+
117
+ # ── API Key sederhana (opsional tapi direkomendasikan) ──────────────────────
118
+ API_KEYS = {"kunci-rahasia-kamu-123"} # ganti dengan key yang kuat
119
+
120
+ def verify_api_key(request: Request):
121
+ key = request.headers.get("X-API-Key")
122
+ if not key or key not in API_KEYS:
123
+ raise HTTPException(status_code=401, detail="API key tidak valid.")
124
+ return key
125
+
126
+ # ── Endpoints ───────────────────────────────────────────────────────────────
127
+ @app.get("/")
128
+ def health():
129
+ return {"status": "ok", "device": str(device)}
130
+
131
+ @app.post("/chat", response_model=ChatResponse)
132
+ @limiter.limit("20/minute") # max 10 request per menit per IP
133
+ @limiter.limit("100/hour") # max 50 request per jam per IP
134
+ async def chat(
135
+ req: ChatRequest,
136
+ request: Request,
137
+ _key: str = Depends(verify_api_key) # hapus baris ini jika tidak pakai API key
138
+ ):
139
+ start = time.time()
140
+
141
+ prompt = f"{req.message} <cot>"
142
  full = generate_text(
143
  model=model, tokenizer=tokenizer, prompt=prompt,
144
+ max_new_tokens=req.max_tokens, temperature=req.temperature,
145
  top_k=50, top_p=0.9, device=device
146
  )
147
  raw = full[len(prompt):].strip()
148
+ thinking, answer = _extract_thinking(raw)
149
+
150
+ elapsed_ms = int((time.time() - start) * 1000)
 
 
 
 
 
 
 
151
 
152
+ return ChatResponse(
153
+ answer=answer if answer else "Maaf, saya tidak mengerti.",
154
+ thinking=thinking if req.show_thinking else None,
155
+ processing_time_ms=elapsed_ms
156
+ )