FaiziRBLX committed on
Commit
6d3ad04
·
verified ·
1 Parent(s): d50a66d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -33
app.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import logging
6
  from collections import defaultdict
7
  from transformers import AutoTokenizer
8
- from fastapi import FastAPI, Request, HTTPException, Depends
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from slowapi import Limiter, _rate_limit_exceeded_handler
11
  from slowapi.util import get_remote_address
@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
21
  device = torch.device('cpu')
22
 
23
  logger.info(f"model.pt ada: {os.path.exists('model.pt')}")
 
24
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
25
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
26
 
@@ -39,22 +40,44 @@ gc.collect()
39
  model.eval()
40
  logger.info("Model siap!")
41
 
42
- # ── FastAPI (untuk /api/chat endpoint) ───────────────────
43
  limiter = Limiter(key_func=get_remote_address)
44
  ip_request_count: dict = defaultdict(list)
45
  ip_banned_until: dict = {}
46
 
47
- fastapi_app = FastAPI()
48
- fastapi_app.state.limiter = limiter
49
- fastapi_app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
50
- fastapi_app.add_middleware(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  CORSMiddleware,
52
  allow_origins=["*"],
53
  allow_methods=["POST", "GET"],
54
  allow_headers=["*"],
55
  )
56
 
57
- @fastapi_app.middleware("http")
58
  async def ddos_protection(request: Request, call_next):
59
  ip = get_remote_address(request)
60
  now = time.time()
@@ -81,19 +104,17 @@ class ChatResponse(BaseModel):
81
  thinking: str | None = None
82
  processing_time_ms: int
83
 
84
- API_KEYS = {"kunci-rahasia-kamu-123"} # ← ganti!
85
-
86
  def verify_api_key(request: Request):
87
  key = request.headers.get("X-API-Key")
88
  if not key or key not in API_KEYS:
89
  raise HTTPException(401, "API key tidak valid.")
90
  return key
91
 
92
- @fastapi_app.get("/api/health")
93
  def health():
94
  return {"status": "ok", "device": str(device)}
95
 
96
- @fastapi_app.post("/api/chat", response_model=ChatResponse)
97
  @limiter.limit("10/minute")
98
  @limiter.limit("50/hour")
99
  async def api_chat(
@@ -116,25 +137,5 @@ async def api_chat(
116
  processing_time_ms=int((time.time() - start) * 1000)
117
  )
118
 
119
- # ── Gradio UI ────────────────────────────────────────────
120
- def gradio_chat(message, history):
121
- prompt = f"{message} <cot>"
122
- full = generate_text(
123
- model=model, tokenizer=tokenizer, prompt=prompt,
124
- max_new_tokens=200, temperature=0.7,
125
- top_k=50, top_p=0.9, device=device
126
- )
127
- raw = full[len(prompt):].strip()
128
- _, answer = _extract_thinking(raw)
129
- return answer if answer else "Maaf, saya tidak mengerti."
130
-
131
- gradio_ui = gr.ChatInterface(
132
- fn=gradio_chat,
133
- title="Indonesian LLM",
134
- description="Chat dengan model bahasa Indonesia"
135
- )
136
-
137
- # ── Mount FastAPI ke Gradio ───────────────────────────────
138
- # Ini kuncinya: Gradio expose FastAPI, kita tinggal mount route kita
139
- demo = gr.mount_gradio_app(fastapi_app, gradio_ui, path="/")
140
-
 
5
  import logging
6
  from collections import defaultdict
7
  from transformers import AutoTokenizer
8
+ from fastapi import Request, HTTPException, Depends
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from slowapi import Limiter, _rate_limit_exceeded_handler
11
  from slowapi.util import get_remote_address
 
21
  device = torch.device('cpu')
22
 
23
  logger.info(f"model.pt ada: {os.path.exists('model.pt')}")
24
+
25
  tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
26
  tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
27
 
 
40
  model.eval()
41
  logger.info("Model siap!")
42
 
43
+ # ── Rate limiter ─────────────────────────────────────────
44
  limiter = Limiter(key_func=get_remote_address)
45
  ip_request_count: dict = defaultdict(list)
46
  ip_banned_until: dict = {}
47
 
48
+ API_KEYS = {"kunci-rahasia-kamu-123"} # ← ganti!
49
+
50
+ # ── Gradio UI ────────────────────────────────────────────
51
+ def gradio_chat(message, history):
52
+ prompt = f"{message} <cot>"
53
+ full = generate_text(
54
+ model=model, tokenizer=tokenizer, prompt=prompt,
55
+ max_new_tokens=200, temperature=0.7,
56
+ top_k=50, top_p=0.9, device=device
57
+ )
58
+ raw = full[len(prompt):].strip()
59
+ _, answer = _extract_thinking(raw)
60
+ return answer if answer else "Maaf, saya tidak mengerti."
61
+
62
+ demo = gr.ChatInterface(
63
+ fn=gradio_chat,
64
+ title="Indonesian LLM",
65
+ description="Chat dengan model bahasa Indonesia"
66
+ )
67
+
68
+ # ── Tambah API route ke Gradio's FastAPI ─────────────────
69
+ app = demo.app # Gradio expose FastAPI internal di sini
70
+
71
+ app.state.limiter = limiter
72
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
73
+ app.add_middleware(
74
  CORSMiddleware,
75
  allow_origins=["*"],
76
  allow_methods=["POST", "GET"],
77
  allow_headers=["*"],
78
  )
79
 
80
+ @app.middleware("http")
81
  async def ddos_protection(request: Request, call_next):
82
  ip = get_remote_address(request)
83
  now = time.time()
 
104
  thinking: str | None = None
105
  processing_time_ms: int
106
 
 
 
107
  def verify_api_key(request: Request):
108
  key = request.headers.get("X-API-Key")
109
  if not key or key not in API_KEYS:
110
  raise HTTPException(401, "API key tidak valid.")
111
  return key
112
 
113
+ @app.get("/api/health")
114
  def health():
115
  return {"status": "ok", "device": str(device)}
116
 
117
+ @app.post("/api/chat", response_model=ChatResponse)
118
  @limiter.limit("10/minute")
119
  @limiter.limit("50/hour")
120
  async def api_chat(
 
137
  processing_time_ms=int((time.time() - start) * 1000)
138
  )
139
 
140
+ # ── Launch ───────────────────────────────────────────────
141
+ demo.launch()