likhonsheikh committed
Commit dffa5d7 · verified · 1 Parent(s): f09fb4b

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +122 -19
app.py CHANGED
@@ -6,6 +6,9 @@ Lightweight CPU-based implementation for Hugging Face Spaces
 import os
 import time
 import uuid
+import logging
+from datetime import datetime
+from logging.handlers import RotatingFileHandler
 from typing import List, Optional, Union
 from contextlib import asynccontextmanager
 
@@ -18,6 +21,46 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
 import json
 
+# ============== Logging Configuration ==============
+LOG_DIR = "/tmp/logs"
+os.makedirs(LOG_DIR, exist_ok=True)
+LOG_FILE = os.path.join(LOG_DIR, "api.log")
+
+# Create formatters
+log_format = logging.Formatter(
+    '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
+# File handler with rotation (10MB max, keep 5 backups)
+file_handler = RotatingFileHandler(
+    LOG_FILE,
+    maxBytes=10*1024*1024,
+    backupCount=5,
+    encoding='utf-8'
+)
+file_handler.setFormatter(log_format)
+file_handler.setLevel(logging.DEBUG)
+
+# Console handler
+console_handler = logging.StreamHandler()
+console_handler.setFormatter(log_format)
+console_handler.setLevel(logging.INFO)
+
+# Root logger
+logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, console_handler])
+logger = logging.getLogger("anthropic-api")
+
+# Also capture uvicorn logs
+for uvicorn_logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
+    uv_log = logging.getLogger(uvicorn_logger)
+    uv_log.handlers = [file_handler, console_handler]
+
+logger.info("=" * 60)
+logger.info(f"Application Startup at {datetime.now().isoformat()}")
+logger.info(f"Log file: {LOG_FILE}")
+logger.info("=" * 60)
+
 # ============== Configuration ==============
 MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct" # Ultra-lightweight 135M model
 MAX_TOKENS_DEFAULT = 1024
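
With this formatter every record renders as "timestamp | level | logger | message", so the startup banner above lands in /tmp/logs/api.log roughly as follows (timestamp illustrative):

2025-06-01 12:00:00 | INFO     | anthropic-api | Log file: /tmp/logs/api.log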
@@ -31,21 +74,29 @@ tokenizer = None
 async def lifespan(app: FastAPI):
     """Load model on startup"""
     global model, tokenizer
-    print(f"Loading model: {MODEL_ID}")
-
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        torch_dtype=torch.float32,
-        device_map=DEVICE,
-        low_cpu_mem_usage=True
-    )
-    model.eval()
-    print("Model loaded successfully!")
+    logger.info(f"Loading model: {MODEL_ID}")
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        logger.info("Tokenizer loaded successfully")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.float32,
+            device_map=DEVICE,
+            low_cpu_mem_usage=True
+        )
+        model.eval()
+        logger.info("Model loaded successfully!")
+        logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+    except Exception as e:
+        logger.error(f"Failed to load model: {e}", exc_info=True)
+        raise
 
     yield
 
     # Cleanup
+    logger.info("Shutting down, cleaning up model...")
     del model, tokenizer
 
 app = FastAPI(
@@ -64,6 +115,24 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+# Request logging middleware
+@app.middleware("http")
+async def log_requests(request: Request, call_next):
+    request_id = str(uuid.uuid4())[:8]
+    start_time = time.time()
+
+    logger.info(f"[{request_id}] {request.method} {request.url.path} - Started")
+
+    try:
+        response = await call_next(request)
+        duration = (time.time() - start_time) * 1000
+        logger.info(f"[{request_id}] {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)")
+        return response
+    except Exception as e:
+        duration = (time.time() - start_time) * 1000
+        logger.error(f"[{request_id}] {request.method} {request.url.path} - Error: {e} ({duration:.2f}ms)")
+        raise
+
 # ============== Pydantic Models (Anthropic-Compatible) ==============
 
 class ContentBlock(BaseModel):
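
Each request therefore logs a paired start/finish line keyed by the first 8 characters of a UUID, e.g. (id and timing illustrative):

[a1b2c3d4] POST /v1/messages - Started
[a1b2c3d4] POST /v1/messages - 200 (1532.10ms)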
@@ -143,16 +212,19 @@ def generate_id() -> str:
 @app.get("/")
 async def root():
     """Health check endpoint"""
+    logger.debug("Root endpoint accessed")
     return {
         "status": "healthy",
         "model": MODEL_ID,
         "api_version": "2023-06-01",
-        "compatibility": "anthropic-messages-api"
+        "compatibility": "anthropic-messages-api",
+        "log_file": LOG_FILE
     }
 
 @app.get("/v1/models")
 async def list_models():
     """List available models (Anthropic-compatible)"""
+    logger.debug("Models list requested")
     return {
         "object": "list",
         "data": [
@@ -166,6 +238,22 @@ async def list_models():
         ]
     }
 
+@app.get("/logs")
+async def get_logs(lines: int = 100):
+    """Get recent log entries"""
+    try:
+        with open(LOG_FILE, 'r') as f:
+            all_lines = f.readlines()
+        recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
+        return {
+            "log_file": LOG_FILE,
+            "total_lines": len(all_lines),
+            "returned_lines": len(recent_lines),
+            "logs": "".join(recent_lines)
+        }
+    except FileNotFoundError:
+        return {"error": "Log file not found", "log_file": LOG_FILE}
+
 @app.post("/v1/messages")
 async def create_message(
     request: MessageRequest,
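
The new /logs endpoint can be polled from outside the Space; a minimal sketch, assuming the server runs locally on port 7860 (the port bound in __main__ below) and that the third-party requests package is installed:

import requests

# Fetch the last 50 lines of the rotating log from the running server
resp = requests.get("http://localhost:7860/logs", params={"lines": 50})
print(resp.json()["logs"])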
@@ -175,18 +263,25 @@ async def create_message(
     """
     Create a message (Anthropic Messages API compatible)
     """
+    message_id = generate_id()
+    logger.info(f"[{message_id}] Creating message - model: {request.model}, max_tokens: {request.max_tokens}, stream: {request.stream}")
+
     try:
         # Format the prompt
         prompt = format_messages(request.messages, request.system)
+        logger.debug(f"[{message_id}] Prompt length: {len(prompt)} chars")
 
         # Tokenize
         inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
         input_token_count = inputs.input_ids.shape[1]
+        logger.info(f"[{message_id}] Input tokens: {input_token_count}")
 
         if request.stream:
-            return await stream_response(request, inputs, input_token_count)
+            logger.info(f"[{message_id}] Starting streaming response")
+            return await stream_response(request, inputs, input_token_count, message_id)
 
         # Generate
+        gen_start = time.time()
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
@@ -198,15 +293,19 @@
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
             )
+        gen_time = time.time() - gen_start
 
         # Decode only new tokens
         generated_tokens = outputs[0][input_token_count:]
         generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        output_token_count = len(generated_tokens)
 
+        tokens_per_sec = output_token_count / gen_time if gen_time > 0 else 0
+        logger.info(f"[{message_id}] Generated {output_token_count} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")
+
         # Build response
         response = MessageResponse(
-            id=generate_id(),
+            id=message_id,
             content=[ContentBlock(type="text", text=generated_text.strip())],
             model=request.model,
             stop_reason="end_turn",
@@ -219,13 +318,12 @@
         return response
 
     except Exception as e:
+        logger.error(f"[{message_id}] Error creating message: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
 
-async def stream_response(request: MessageRequest, inputs, input_token_count: int):
+async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
     """Stream response using SSE (Server-Sent Events)"""
 
-    message_id = generate_id()
-
     async def generate():
         # Send message_start event
         start_event = {
@@ -267,6 +365,7 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int):
         }
 
         # Run generation in a thread
+        gen_start = time.time()
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
 
@@ -282,6 +381,9 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int):
             yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
 
         thread.join()
+        gen_time = time.time() - gen_start
+        tokens_per_sec = output_tokens / gen_time if gen_time > 0 else 0
+        logger.info(f"[{message_id}] Stream completed: {output_tokens} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")
 
         # Send content_block_stop
         block_stop = {"type": "content_block_stop", "index": 0}
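
On the wire each delta goes out as a standard SSE frame: an event line, a data line, and a blank separator. The delta payload is built outside the hunks shown here, so its fields below are illustrative:

event: content_block_delta
data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}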
@@ -314,13 +416,14 @@ async def count_tokens(request: MessageRequest):
     """Count tokens for a message request"""
     prompt = format_messages(request.messages, request.system)
     tokens = tokenizer.encode(prompt)
+    logger.debug(f"Token count request: {len(tokens)} tokens")
     return {"input_tokens": len(tokens)}
 
 # Health check
 @app.get("/health")
 async def health():
-    return {"status": "ok", "model_loaded": model is not None}
+    return {"status": "ok", "model_loaded": model is not None, "log_file": LOG_FILE}
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_config=None)
 
6
  import os
7
  import time
8
  import uuid
9
+ import logging
10
+ from datetime import datetime
11
+ from logging.handlers import RotatingFileHandler
12
  from typing import List, Optional, Union
13
  from contextlib import asynccontextmanager
14
 
 
21
  from threading import Thread
22
  import json
23
 
24
+ # ============== Logging Configuration ==============
25
+ LOG_DIR = "/tmp/logs"
26
+ os.makedirs(LOG_DIR, exist_ok=True)
27
+ LOG_FILE = os.path.join(LOG_DIR, "api.log")
28
+
29
+ # Create formatters
30
+ log_format = logging.Formatter(
31
+ '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
32
+ datefmt='%Y-%m-%d %H:%M:%S'
33
+ )
34
+
35
+ # File handler with rotation (10MB max, keep 5 backups)
36
+ file_handler = RotatingFileHandler(
37
+ LOG_FILE,
38
+ maxBytes=10*1024*1024,
39
+ backupCount=5,
40
+ encoding='utf-8'
41
+ )
42
+ file_handler.setFormatter(log_format)
43
+ file_handler.setLevel(logging.DEBUG)
44
+
45
+ # Console handler
46
+ console_handler = logging.StreamHandler()
47
+ console_handler.setFormatter(log_format)
48
+ console_handler.setLevel(logging.INFO)
49
+
50
+ # Root logger
51
+ logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, console_handler])
52
+ logger = logging.getLogger("anthropic-api")
53
+
54
+ # Also capture uvicorn logs
55
+ for uvicorn_logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
56
+ uv_log = logging.getLogger(uvicorn_logger)
57
+ uv_log.handlers = [file_handler, console_handler]
58
+
59
+ logger.info("=" * 60)
60
+ logger.info(f"Application Startup at {datetime.now().isoformat()}")
61
+ logger.info(f"Log file: {LOG_FILE}")
62
+ logger.info("=" * 60)
63
+
64
  # ============== Configuration ==============
65
  MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct" # Ultra-lightweight 135M model
66
  MAX_TOKENS_DEFAULT = 1024
 
74
  async def lifespan(app: FastAPI):
75
  """Load model on startup"""
76
  global model, tokenizer
77
+ logger.info(f"Loading model: {MODEL_ID}")
78
+
79
+ try:
80
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
81
+ logger.info("Tokenizer loaded successfully")
82
+
83
+ model = AutoModelForCausalLM.from_pretrained(
84
+ MODEL_ID,
85
+ torch_dtype=torch.float32,
86
+ device_map=DEVICE,
87
+ low_cpu_mem_usage=True
88
+ )
89
+ model.eval()
90
+ logger.info("Model loaded successfully!")
91
+ logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
92
+ except Exception as e:
93
+ logger.error(f"Failed to load model: {e}", exc_info=True)
94
+ raise
95
 
96
  yield
97
 
98
  # Cleanup
99
+ logger.info("Shutting down, cleaning up model...")
100
  del model, tokenizer
101
 
102
  app = FastAPI(
 
115
  allow_headers=["*"],
116
  )
117
 
118
+ # Request logging middleware
119
+ @app.middleware("http")
120
+ async def log_requests(request: Request, call_next):
121
+ request_id = str(uuid.uuid4())[:8]
122
+ start_time = time.time()
123
+
124
+ logger.info(f"[{request_id}] {request.method} {request.url.path} - Started")
125
+
126
+ try:
127
+ response = await call_next(request)
128
+ duration = (time.time() - start_time) * 1000
129
+ logger.info(f"[{request_id}] {request.method} {request.url.path} - {response.status_code} ({duration:.2f}ms)")
130
+ return response
131
+ except Exception as e:
132
+ duration = (time.time() - start_time) * 1000
133
+ logger.error(f"[{request_id}] {request.method} {request.url.path} - Error: {e} ({duration:.2f}ms)")
134
+ raise
135
+
136
  # ============== Pydantic Models (Anthropic-Compatible) ==============
137
 
138
  class ContentBlock(BaseModel):
 
212
  @app.get("/")
213
  async def root():
214
  """Health check endpoint"""
215
+ logger.debug("Root endpoint accessed")
216
  return {
217
  "status": "healthy",
218
  "model": MODEL_ID,
219
  "api_version": "2023-06-01",
220
+ "compatibility": "anthropic-messages-api",
221
+ "log_file": LOG_FILE
222
  }
223
 
224
  @app.get("/v1/models")
225
  async def list_models():
226
  """List available models (Anthropic-compatible)"""
227
+ logger.debug("Models list requested")
228
  return {
229
  "object": "list",
230
  "data": [
 
238
  ]
239
  }
240
 
241
+ @app.get("/logs")
242
+ async def get_logs(lines: int = 100):
243
+ """Get recent log entries"""
244
+ try:
245
+ with open(LOG_FILE, 'r') as f:
246
+ all_lines = f.readlines()
247
+ recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
248
+ return {
249
+ "log_file": LOG_FILE,
250
+ "total_lines": len(all_lines),
251
+ "returned_lines": len(recent_lines),
252
+ "logs": "".join(recent_lines)
253
+ }
254
+ except FileNotFoundError:
255
+ return {"error": "Log file not found", "log_file": LOG_FILE}
256
+
257
  @app.post("/v1/messages")
258
  async def create_message(
259
  request: MessageRequest,
 
263
  """
264
  Create a message (Anthropic Messages API compatible)
265
  """
266
+ message_id = generate_id()
267
+ logger.info(f"[{message_id}] Creating message - model: {request.model}, max_tokens: {request.max_tokens}, stream: {request.stream}")
268
+
269
  try:
270
  # Format the prompt
271
  prompt = format_messages(request.messages, request.system)
272
+ logger.debug(f"[{message_id}] Prompt length: {len(prompt)} chars")
273
 
274
  # Tokenize
275
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
276
  input_token_count = inputs.input_ids.shape[1]
277
+ logger.info(f"[{message_id}] Input tokens: {input_token_count}")
278
 
279
  if request.stream:
280
+ logger.info(f"[{message_id}] Starting streaming response")
281
+ return await stream_response(request, inputs, input_token_count, message_id)
282
 
283
  # Generate
284
+ gen_start = time.time()
285
  with torch.no_grad():
286
  outputs = model.generate(
287
  **inputs,
 
293
  pad_token_id=tokenizer.eos_token_id,
294
  eos_token_id=tokenizer.eos_token_id,
295
  )
296
+ gen_time = time.time() - gen_start
297
 
298
  # Decode only new tokens
299
  generated_tokens = outputs[0][input_token_count:]
300
  generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
301
  output_token_count = len(generated_tokens)
302
 
303
+ tokens_per_sec = output_token_count / gen_time if gen_time > 0 else 0
304
+ logger.info(f"[{message_id}] Generated {output_token_count} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")
305
+
306
  # Build response
307
  response = MessageResponse(
308
+ id=message_id,
309
  content=[ContentBlock(type="text", text=generated_text.strip())],
310
  model=request.model,
311
  stop_reason="end_turn",
 
318
  return response
319
 
320
  except Exception as e:
321
+ logger.error(f"[{message_id}] Error creating message: {e}", exc_info=True)
322
  raise HTTPException(status_code=500, detail=str(e))
323
 
324
+ async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
325
  """Stream response using SSE (Server-Sent Events)"""
326
 
 
 
327
  async def generate():
328
  # Send message_start event
329
  start_event = {
 
365
  }
366
 
367
  # Run generation in a thread
368
+ gen_start = time.time()
369
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
370
  thread.start()
371
 
 
381
  yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
382
 
383
  thread.join()
384
+ gen_time = time.time() - gen_start
385
+ tokens_per_sec = output_tokens / gen_time if gen_time > 0 else 0
386
+ logger.info(f"[{message_id}] Stream completed: {output_tokens} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")
387
 
388
  # Send content_block_stop
389
  block_stop = {"type": "content_block_stop", "index": 0}
 
416
  """Count tokens for a message request"""
417
  prompt = format_messages(request.messages, request.system)
418
  tokens = tokenizer.encode(prompt)
419
+ logger.debug(f"Token count request: {len(tokens)} tokens")
420
  return {"input_tokens": len(tokens)}
421
 
422
  # Health check
423
  @app.get("/health")
424
  async def health():
425
+ return {"status": "ok", "model_loaded": model is not None, "log_file": LOG_FILE}
426
 
427
  if __name__ == "__main__":
428
  import uvicorn
429
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_config=None)
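
For reference, a minimal client sketch against the committed Messages endpoint. The base URL is an assumption (use the Space URL when deployed), the requests dependency is not part of this commit, and whether "content" may be a bare string depends on the Message model, which this diff does not show; the payload follows the usual Anthropic Messages shape:

import requests

BASE = "http://localhost:7860"  # assumed local run

# Non-streaming request; the server echoes the model field back
payload = {
    "model": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "max_tokens": 128,
    "messages": [{"role": "user", "content": "Say hello."}],
}
resp = requests.post(f"{BASE}/v1/messages", json=payload)
resp.raise_for_status()

# MessageResponse carries a list of content blocks; print the text of the first
print(resp.json()["content"][0]["text"])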