FaiziRBLX committed on
Commit
907d439
·
verified ·
1 Parent(s): a93f50a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -72
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
- import time
3
- import hashlib
4
- from collections import defaultdict
5
  from transformers import AutoTokenizer
6
  from fastapi import FastAPI, Request, HTTPException, Depends
7
  from fastapi.middleware.cors import CORSMiddleware
@@ -10,102 +10,116 @@ from slowapi import Limiter, _rate_limit_exceeded_handler
10
  from slowapi.util import get_remote_address
11
  from slowapi.errors import RateLimitExceeded
12
  from pydantic import BaseModel, Field
 
13
  from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
 
14
 
15
- # ── Load model ──────────────────────────────────────────────────────────────
16
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
 
18
- tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
19
- tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
 
 
 
 
20
 
21
- checkpoint = torch.load("model.pt", map_location=device, weights_only=False)
22
- config = checkpoint['config']
23
- model = IndonesianLLM(config)
24
 
25
- state_dict = checkpoint['model_state_dict']
26
- if checkpoint.get('dtype') == 'fp16':
27
- state_dict = {k: v.float() if v.dtype == torch.float16 else v
28
- for k, v in state_dict.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
- model.load_state_dict(state_dict)
31
  model.eval()
32
- model.to(device)
33
 
34
- # ── Rate Limiter (slowapi) ───────────────────────────────────────────────────
35
  limiter = Limiter(key_func=get_remote_address)
 
 
 
 
 
36
 
37
- # ── IP Blacklist (in-memory, reset saat restart) ────────────────────────────
38
- ip_blacklist: set = set()
39
- ip_request_count: dict = defaultdict(list) # ip -> [timestamp, ...]
40
-
41
- BLACKLIST_THRESHOLD = 100 # request dalam window ini → blacklist
42
- BLACKLIST_WINDOW = 60 # detik
43
- BLACKLIST_DURATION = 3600 # banned 1 jam (simpan di set terpisah)
44
- ip_banned_until: dict = {} # ip -> timestamp banned sampai kapan
45
-
46
- # ── FastAPI setup ───────────────────────────────────────────────────────────
47
  app = FastAPI(title="Indonesian LLM API")
48
  app.state.limiter = limiter
49
  app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
50
 
51
- # CORS — ganti origins sesuai domain kamu
52
  app.add_middleware(
53
  CORSMiddleware,
54
- allow_origins=["https://nousai.netlify.app"], # ganti! jangan "*" di production
55
  allow_methods=["POST", "GET"],
56
  allow_headers=["*"],
57
  )
58
 
59
- # Trusted hosts — tolak request dengan Host header aneh
60
- app.add_middleware(
61
- TrustedHostMiddleware,
62
- allowed_hosts=["yourdomain.com", "localhost", "127.0.0.1"]
63
- )
64
-
65
- # ── Middleware: DDoS / Flood Detection ──────────────────────────────────────
66
  @app.middleware("http")
67
  async def ddos_protection(request: Request, call_next):
68
- ip = get_remote_address(request)
69
  now = time.time()
70
 
71
- # Cek apakah IP sedang dibanned
72
  if ip in ip_banned_until:
73
  if now < ip_banned_until[ip]:
74
  remaining = int(ip_banned_until[ip] - now)
75
- return HTTPException(
76
- status_code=429,
77
- detail=f"IP banned. Coba lagi dalam {remaining} detik."
78
- )
79
  else:
80
- # Ban sudah habis
81
  del ip_banned_until[ip]
82
  ip_request_count[ip] = []
83
 
84
- # Catat timestamp request ini
85
  ip_request_count[ip].append(now)
 
86
 
87
- # Bersihkan request yang sudah di luar window
88
- ip_request_count[ip] = [
89
- t for t in ip_request_count[ip]
90
- if now - t < BLACKLIST_WINDOW
91
- ]
92
-
93
- # Jika terlalu banyak request → ban
94
  if len(ip_request_count[ip]) > BLACKLIST_THRESHOLD:
95
  ip_banned_until[ip] = now + BLACKLIST_DURATION
96
  ip_request_count[ip] = []
97
- raise HTTPException(
98
- status_code=429,
99
- detail=f"Terlalu banyak request. IP dibanned selama {BLACKLIST_DURATION//60} menit."
100
- )
101
 
102
- response = await call_next(request)
103
- return response
104
 
105
- # ── Request/Response Schema ─────────────────────────────────────────────────
106
  class ChatRequest(BaseModel):
107
- message: str = Field(..., min_length=1, max_length=500) # batasi panjang input
108
- max_tokens: int = Field(default=200, ge=10, le=500) # min 10, max 500
109
  temperature: float = Field(default=0.7, ge=0.1, le=1.5)
110
  show_thinking: bool = False
111
 
@@ -114,31 +128,30 @@ class ChatResponse(BaseModel):
114
  thinking: str | None = None
115
  processing_time_ms: int
116
 
117
- # ── API Key sederhana (opsional tapi direkomendasikan) ──────────────────────
118
- API_KEYS = {"kunci-rahasia-kamu-123"} # ganti dengan key yang kuat
119
 
120
  def verify_api_key(request: Request):
121
  key = request.headers.get("X-API-Key")
122
  if not key or key not in API_KEYS:
123
- raise HTTPException(status_code=401, detail="API key tidak valid.")
124
  return key
125
 
126
- # ── Endpoints ───────────────────────────────────────────────────────────────
127
  @app.get("/")
128
  def health():
129
  return {"status": "ok", "device": str(device)}
130
 
131
  @app.post("/chat", response_model=ChatResponse)
132
- @limiter.limit("20/minute") # max 10 request per menit per IP
133
- @limiter.limit("100/hour") # max 50 request per jam per IP
134
  async def chat(
135
  req: ChatRequest,
136
  request: Request,
137
- _key: str = Depends(verify_api_key) # hapus baris ini jika tidak pakai API key
138
  ):
139
- start = time.time()
140
-
141
  prompt = f"{req.message} <cot>"
 
142
  full = generate_text(
143
  model=model, tokenizer=tokenizer, prompt=prompt,
144
  max_new_tokens=req.max_tokens, temperature=req.temperature,
@@ -147,10 +160,8 @@ async def chat(
147
  raw = full[len(prompt):].strip()
148
  thinking, answer = _extract_thinking(raw)
149
 
150
- elapsed_ms = int((time.time() - start) * 1000)
151
-
152
  return ChatResponse(
153
  answer=answer if answer else "Maaf, saya tidak mengerti.",
154
  thinking=thinking if req.show_thinking else None,
155
- processing_time_ms=elapsed_ms
156
  )
 
1
  import torch
2
+ import os
3
+ import logging
4
+ import gc
5
  from transformers import AutoTokenizer
6
  from fastapi import FastAPI, Request, HTTPException, Depends
7
  from fastapi.middleware.cors import CORSMiddleware
 
10
  from slowapi.util import get_remote_address
11
  from slowapi.errors import RateLimitExceeded
12
  from pydantic import BaseModel, Field
13
+ from collections import defaultdict
14
  from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
15
+ import time
16
 
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
 
20
+ # ── Cek file ────────────────────────────────────────────
21
+ logger.info(f"model.pt ada: {os.path.exists('model.pt')}")
22
+ if os.path.exists('model.pt'):
23
+ logger.info(f"model.pt size: {os.path.getsize('model.pt') / 1e6:.1f} MB")
24
+ else:
25
+ raise FileNotFoundError("model.pt tidak ditemukan! Upload dulu ke Space.")
26
 
27
+ # ── Device ──────────────────────────────────────────────
28
+ device = torch.device('cpu') # HF Spaces free = CPU only
29
+ logger.info(f"Device: {device}")
30
 
31
+ # ── Tokenizer ───────────────────────────────────────────
32
+ logger.info("Loading tokenizer...")
33
+ tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
34
+ tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
35
+ logger.info("Tokenizer OK")
36
+
37
+ # ── Model ───────────────────────────────────────────────
38
+ logger.info("Loading checkpoint...")
39
+ try:
40
+ checkpoint = torch.load("model.pt", map_location='cpu', weights_only=False)
41
+ logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")
42
+ except Exception as e:
43
+ logger.error(f"GAGAL load checkpoint: {e}")
44
+ raise
45
+
46
+ logger.info("Building model...")
47
+ try:
48
+ config = checkpoint['config']
49
+ model = IndonesianLLM(config)
50
+ logger.info(f"Model params: {model.count_parameters():,}")
51
+ except Exception as e:
52
+ logger.error(f"GAGAL buat model: {e}")
53
+ raise
54
+
55
+ logger.info("Loading weights...")
56
+ try:
57
+ state_dict = checkpoint['model_state_dict']
58
+ # Konversi fp16 → fp32 in-place (hemat RAM)
59
+ for k in list(state_dict.keys()):
60
+ if state_dict[k].dtype == torch.float16:
61
+ state_dict[k] = state_dict[k].float()
62
+ model.load_state_dict(state_dict)
63
+ logger.info("Weights OK")
64
+ except Exception as e:
65
+ logger.error(f"GAGAL load weights: {e}")
66
+ raise
67
+
68
+ # Bebaskan RAM
69
+ del checkpoint, state_dict
70
+ gc.collect()
71
+ logger.info(f"RAM setelah cleanup: {torch.cuda.memory_allocated()/1e6:.1f} MB (GPU)" if torch.cuda.is_available() else "RAM cleanup done")
72
 
 
73
  model.eval()
74
+ logger.info("Model siap!")
75
 
76
+ # ── Rate limiter ─────────────────────────────────────────
77
  limiter = Limiter(key_func=get_remote_address)
78
+ ip_request_count: dict = defaultdict(list)
79
+ ip_banned_until: dict = {}
80
+ BLACKLIST_THRESHOLD = 100
81
+ BLACKLIST_WINDOW = 60
82
+ BLACKLIST_DURATION = 3600
83
 
84
+ # ── FastAPI ──────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
85
  app = FastAPI(title="Indonesian LLM API")
86
  app.state.limiter = limiter
87
  app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
88
 
 
89
  app.add_middleware(
90
  CORSMiddleware,
91
+ allow_origins=["*"], # ganti dengan domain kamu di production
92
  allow_methods=["POST", "GET"],
93
  allow_headers=["*"],
94
  )
95
 
 
 
 
 
 
 
 
96
  @app.middleware("http")
97
  async def ddos_protection(request: Request, call_next):
98
+ ip = get_remote_address(request)
99
  now = time.time()
100
 
 
101
  if ip in ip_banned_until:
102
  if now < ip_banned_until[ip]:
103
  remaining = int(ip_banned_until[ip] - now)
104
+ raise HTTPException(429, f"IP banned. Coba lagi dalam {remaining}s.")
 
 
 
105
  else:
 
106
  del ip_banned_until[ip]
107
  ip_request_count[ip] = []
108
 
 
109
  ip_request_count[ip].append(now)
110
+ ip_request_count[ip] = [t for t in ip_request_count[ip] if now - t < BLACKLIST_WINDOW]
111
 
 
 
 
 
 
 
 
112
  if len(ip_request_count[ip]) > BLACKLIST_THRESHOLD:
113
  ip_banned_until[ip] = now + BLACKLIST_DURATION
114
  ip_request_count[ip] = []
115
+ raise HTTPException(429, f"Terlalu banyak request. Banned {BLACKLIST_DURATION//60} menit.")
 
 
 
116
 
117
+ return await call_next(request)
 
118
 
119
+ # ── Schema ───────────────────────────────────────────────
120
  class ChatRequest(BaseModel):
121
+ message: str = Field(..., min_length=1, max_length=500)
122
+ max_tokens: int = Field(default=200, ge=10, le=500)
123
  temperature: float = Field(default=0.7, ge=0.1, le=1.5)
124
  show_thinking: bool = False
125
 
 
128
  thinking: str | None = None
129
  processing_time_ms: int
130
 
131
+ API_KEYS = {"kunci-rahasia-kamu-123"} # ← ganti!
 
132
 
133
  def verify_api_key(request: Request):
134
  key = request.headers.get("X-API-Key")
135
  if not key or key not in API_KEYS:
136
+ raise HTTPException(401, "API key tidak valid.")
137
  return key
138
 
139
+ # ── Endpoints ─────────────────────────────────────────────
140
  @app.get("/")
141
  def health():
142
  return {"status": "ok", "device": str(device)}
143
 
144
  @app.post("/chat", response_model=ChatResponse)
145
+ @limiter.limit("10/minute")
146
+ @limiter.limit("50/hour")
147
  async def chat(
148
  req: ChatRequest,
149
  request: Request,
150
+ _key: str = Depends(verify_api_key)
151
  ):
152
+ start = time.time()
 
153
  prompt = f"{req.message} <cot>"
154
+
155
  full = generate_text(
156
  model=model, tokenizer=tokenizer, prompt=prompt,
157
  max_new_tokens=req.max_tokens, temperature=req.temperature,
 
160
  raw = full[len(prompt):].strip()
161
  thinking, answer = _extract_thinking(raw)
162
 
 
 
163
  return ChatResponse(
164
  answer=answer if answer else "Maaf, saya tidak mengerti.",
165
  thinking=thinking if req.show_thinking else None,
166
+ processing_time_ms=int((time.time() - start) * 1000)
167
  )