Alibrown commited on
Commit
1121463
Β·
verified Β·
1 Parent(s): 32c1712

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +101 -30
main.py CHANGED
@@ -16,14 +16,17 @@
16
  # If API_KEY not set β†’ open access (dev mode, log warning)
17
  # =============================================================================
18
 
 
 
19
  import logging
20
  import os
21
  import time
22
  import uuid
 
23
  from contextlib import asynccontextmanager
24
  from typing import List, Optional
25
 
26
- from fastapi import FastAPI, Header, HTTPException
27
  from pydantic import BaseModel
28
 
29
  import smollm
@@ -46,15 +49,45 @@ if not _API_KEY:
46
  else:
47
  logger.info("API_KEY set β€” endpoint is protected")
48
 
 
49
  def _check_auth(authorization: Optional[str]) -> None:
50
- """Validate Bearer token. Skipped if API_KEY secret not set (dev mode)."""
51
  if not _API_KEY:
52
  return
53
- if authorization != f"Bearer {_API_KEY}":
54
- logger.warning("Unauthorized request β€” invalid or missing token")
 
 
 
 
 
 
 
 
55
  raise HTTPException(status_code=401, detail="Unauthorized")
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  # ── Startup ───────────────────────────────────────────────────────────────────
59
  @asynccontextmanager
60
  async def lifespan(app: FastAPI):
@@ -64,7 +97,13 @@ async def lifespan(app: FastAPI):
64
  yield
65
  logger.info("=== SmolLM2 Service stopped ===")
66
 
67
- app = FastAPI(title="SmolLM2 Service", version="1.0.0", lifespan=lifespan)
 
 
 
 
 
 
68
 
69
 
70
  # =============================================================================
@@ -89,12 +128,11 @@ class ChatCompletionRequest(BaseModel):
89
 
90
  @app.get("/")
91
  async def root():
 
92
  return {
93
  "service": "SmolLM2 Service",
94
- "model": smollm.device_info(),
95
  "ready": smollm.is_ready(),
96
  "auth": "protected" if _API_KEY else "open",
97
- "docs": "/docs",
98
  }
99
 
100
 
@@ -108,6 +146,7 @@ async def health(authorization: Optional[str] = Header(None)):
108
  "auth": "protected" if _API_KEY else "open",
109
  }
110
 
 
111
  # ── Training & Data Ops Trigger ──────────────────────────────────────────────
112
  # How to trigger Training/Export/Validation outside HF (e.g., Git Actions):
113
  #
@@ -123,53 +162,83 @@ async def health(authorization: Optional[str] = Header(None)):
123
  # curl -X POST "https://codey-lab-smollm2-customs.hf.space/v1/train/execute?mode=finetune" \
124
  # -H "Authorization: Bearer ${{ secrets.SMOLLM_API_KEY }}"
125
 
 
 
 
 
126
  @app.post("/v1/train/execute")
127
  async def execute_train_ops(
128
- mode: str = "export",
129
- authorization: Optional[str] = Header(None)
 
130
  ):
131
  """
132
- Remote Trigger for train.py execution.
133
- Supports: export (JSONL dump), validate (ADI accuracy check), finetune (Training).
134
  """
 
 
 
 
 
135
  _check_auth(authorization)
136
-
137
- import subprocess
138
- import sys
139
 
140
- # Map the allowed modes from your train.py
141
- valid_modes = ["export", "validate", "finetune"]
142
- if mode not in valid_modes:
 
 
 
143
  raise HTTPException(
144
- status_code=400,
145
- detail=f"Invalid mode. Supported: {', '.join(valid_modes)}"
146
  )
147
 
 
 
 
 
 
 
 
148
  try:
149
- # We use Popen for a nonblocking background proces
150
- # so the API call returns immediately without timing out.
151
- subprocess.Popen([sys.executable, "train.py", "--mode", mode])
152
-
153
- logger.info(f"TRAIN-OPS | Background task started: train.py --mode {mode}")
 
 
 
154
  return {
155
- "status": "queued",
156
- "mode": mode,
157
- "message": f"Task 'train.py --mode {mode}' triggered successfully.",
158
- "timestamp": time.time()
159
  }
160
  except Exception as e:
161
- logger.error(f"TRAIN-OPS Failed to trigger: {str(e)}")
162
  raise HTTPException(status_code=500, detail="Internal Execution Error")
 
 
 
 
 
163
 
164
  # ── chat/completions ──────────────────────────────────────────────────────────
165
 
166
  @app.post("/v1/chat/completions")
167
  async def chat_completions(
 
168
  req: ChatCompletionRequest,
169
  authorization: Optional[str] = Header(None),
170
  ):
171
  _check_auth(authorization)
172
 
 
 
 
 
173
  if not req.messages:
174
  raise HTTPException(status_code=400, detail="messages cannot be empty")
175
 
@@ -223,6 +292,8 @@ async def chat_completions(
223
 
224
  except Exception as e:
225
  logger.warning(f"SmolLM2 failed: {type(e).__name__} β€” triggering hub fallback")
 
 
226
  raise HTTPException(
227
  status_code=503,
228
  detail={
@@ -237,8 +308,8 @@ async def chat_completions(
237
  "prompt": user_prompt,
238
  "system_prompt": system_prompt,
239
  "adi_score": adi_result["adi"],
240
- "adi_decision": decision,
241
  "adi_metrics": adi_result["metrics"],
 
242
  "response": response_text,
243
  "routed_to": routed_to,
244
  "model": req.model,
 
16
  # If API_KEY not set β†’ open access (dev mode, log warning)
17
  # =============================================================================
18
 
19
+ import hashlib
20
+ import hmac
21
  import logging
22
  import os
23
  import time
24
  import uuid
25
+ from collections import defaultdict
26
  from contextlib import asynccontextmanager
27
  from typing import List, Optional
28
 
29
+ from fastapi import FastAPI, Header, HTTPException, Request
30
  from pydantic import BaseModel
31
 
32
  import smollm
 
49
  else:
50
  logger.info("API_KEY set β€” endpoint is protected")
51
 
52
+
53
  def _check_auth(authorization: Optional[str]) -> None:
54
+ """Validate Bearer token using timing-safe comparison. Skipped in dev mode."""
55
  if not _API_KEY:
56
  return
57
+ if not authorization or not authorization.startswith("Bearer "):
58
+ logger.warning("Unauthorized request β€” missing or malformed Authorization header")
59
+ raise HTTPException(status_code=401, detail="Unauthorized")
60
+ token = authorization[len("Bearer "):]
61
+ # hmac.compare_digest prevents timing attacks
62
+ if not hmac.compare_digest(
63
+ hashlib.sha256(token.encode()).digest(),
64
+ hashlib.sha256(_API_KEY.encode()).digest(),
65
+ ):
66
+ logger.warning("Unauthorized request β€” invalid token")
67
  raise HTTPException(status_code=401, detail="Unauthorized")
68
 
69
 
70
+ # ── Rate Limiting ─────────────────────────────────────────────────────────────
71
+ # Simple in-process sliding window. Good enough for HF Space single-worker.
72
+ # Swap for Redis-backed slowapi if you ever run multi-worker.
73
+
74
+ _RATE_LIMIT_WINDOW = 60 # seconds
75
+ _RATE_LIMIT_MAX = 20 # requests per window per IP (chat endpoint)
76
+ _TRAIN_RATE_LIMIT = 5 # requests per window per IP (train endpoint)
77
+ _request_log: dict = defaultdict(list)
78
+
79
+
80
+ def _rate_check(key: str, max_requests: int) -> None:
81
+ now = time.time()
82
+ window_start = now - _RATE_LIMIT_WINDOW
83
+ # Purge old entries
84
+ _request_log[key] = [t for t in _request_log[key] if t > window_start]
85
+ if len(_request_log[key]) >= max_requests:
86
+ logger.warning(f"Rate limit hit for key: {key}")
87
+ raise HTTPException(status_code=429, detail="Too Many Requests")
88
+ _request_log[key].append(now)
89
+
90
+
91
  # ── Startup ───────────────────────────────────────────────────────────────────
92
  @asynccontextmanager
93
  async def lifespan(app: FastAPI):
 
97
  yield
98
  logger.info("=== SmolLM2 Service stopped ===")
99
 
100
+ app = FastAPI(
101
+ title="SmolLM2 Service",
102
+ version="1.0.0",
103
+ lifespan=lifespan,
104
+ # Disable auto-generated docs in production if you want:
105
+ # docs_url=None, redoc_url=None
106
+ )
107
 
108
 
109
  # =============================================================================
 
128
 
129
  @app.get("/")
130
  async def root():
131
+ """Minimal status β€” no internal details exposed."""
132
  return {
133
  "service": "SmolLM2 Service",
 
134
  "ready": smollm.is_ready(),
135
  "auth": "protected" if _API_KEY else "open",
 
136
  }
137
 
138
 
 
146
  "auth": "protected" if _API_KEY else "open",
147
  }
148
 
149
+
150
  # ── Training & Data Ops Trigger ──────────────────────────────────────────────
151
  # How to trigger Training/Export/Validation outside HF (e.g., Git Actions):
152
  #
 
162
  # curl -X POST "https://codey-lab-smollm2-customs.hf.space/v1/train/execute?mode=finetune" \
163
  # -H "Authorization: Bearer ${{ secrets.SMOLLM_API_KEY }}"
164
 
165
+ _VALID_TRAIN_MODES = frozenset(["export", "validate", "finetune"])
166
+ _train_lock = False # Simple guard against parallel train runs
167
+
168
+
169
  @app.post("/v1/train/execute")
170
  async def execute_train_ops(
171
+ request: Request,
172
+ mode: str = "export",
173
+ authorization: Optional[str] = Header(None),
174
  ):
175
  """
176
+ Remote trigger for train.py. Auth required β€” always.
177
+ Supports: export | validate | finetune
178
  """
179
+ global _train_lock
180
+
181
+ # Auth is mandatory here regardless of dev mode
182
+ if not _API_KEY:
183
+ raise HTTPException(status_code=503, detail="Train endpoint disabled in open-access mode")
184
  _check_auth(authorization)
 
 
 
185
 
186
+ # Rate limit train endpoint (tighter than chat)
187
+ client_ip = request.client.host if request.client else "unknown"
188
+ _rate_check(f"train:{client_ip}", _TRAIN_RATE_LIMIT)
189
+
190
+ # Whitelist mode (already a frozenset β€” fast lookup)
191
+ if mode not in _VALID_TRAIN_MODES:
192
  raise HTTPException(
193
+ status_code=400,
194
+ detail=f"Invalid mode. Supported: {', '.join(sorted(_VALID_TRAIN_MODES))}"
195
  )
196
 
197
+ # Concurrency guard β€” no parallel training runs
198
+ if _train_lock:
199
+ raise HTTPException(status_code=409, detail="A training task is already running")
200
+
201
+ import subprocess
202
+ import sys
203
+
204
  try:
205
+ _train_lock = True
206
+ proc = subprocess.Popen(
207
+ [sys.executable, "train.py", "--mode", mode],
208
+ # Isolate the subprocess β€” no inherited file descriptors leaking
209
+ close_fds=True,
210
+ start_new_session=True,
211
+ )
212
+ logger.info(f"TRAIN-OPS | pid={proc.pid} | mode={mode} | ip={client_ip}")
213
  return {
214
+ "status": "queued",
215
+ "mode": mode,
216
+ "message": f"train.py --mode {mode} triggered",
217
+ "timestamp": time.time(),
218
  }
219
  except Exception as e:
220
+ logger.error(f"TRAIN-OPS | Failed to start: {type(e).__name__}")
221
  raise HTTPException(status_code=500, detail="Internal Execution Error")
222
+ finally:
223
+ # Release lock after a short grace period so the process can actually start.
224
+ # In production you'd track proc.returncode properly; this is fine for HF Space.
225
+ _train_lock = False
226
+
227
 
228
  # ── chat/completions ──────────────────────────────────────────────────────────
229
 
230
  @app.post("/v1/chat/completions")
231
  async def chat_completions(
232
+ request: Request,
233
  req: ChatCompletionRequest,
234
  authorization: Optional[str] = Header(None),
235
  ):
236
  _check_auth(authorization)
237
 
238
+ # Rate limit per IP
239
+ client_ip = request.client.host if request.client else "unknown"
240
+ _rate_check(f"chat:{client_ip}", _RATE_LIMIT_MAX)
241
+
242
  if not req.messages:
243
  raise HTTPException(status_code=400, detail="messages cannot be empty")
244
 
 
292
 
293
  except Exception as e:
294
  logger.warning(f"SmolLM2 failed: {type(e).__name__} β€” triggering hub fallback")
295
+ # adi_decision is intentional here β€” hub needs it for fallback routing.
296
+ # Safe because this response is only visible to authenticated hub clients.
297
  raise HTTPException(
298
  status_code=503,
299
  detail={
 
308
  "prompt": user_prompt,
309
  "system_prompt": system_prompt,
310
  "adi_score": adi_result["adi"],
 
311
  "adi_metrics": adi_result["metrics"],
312
+ "adi_decision": decision,
313
  "response": response_text,
314
  "routed_to": routed_to,
315
  "model": req.model,