likhonsheikh committed on
Commit 49560dc · verified · 1 Parent(s): 5654ea3

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +243 -115
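The commit message says app.py was uploaded with huggingface_hub. A minimal sketch of such an upload call is shown below; the Space repo id is a placeholder assumption, not taken from this commit, and the token is picked up from the local login or HF_TOKEN.

    from huggingface_hub import HfApi

    api = HfApi()  # uses the cached login token or HF_TOKEN
    api.upload_file(
        path_or_fileobj="app.py",             # local file to push
        path_in_repo="app.py",                # destination path inside the repo
        repo_id="likhonsheikh/<space-name>",  # assumed Space id; replace with the real one
        repo_type="space",
        commit_message="Upload app.py with huggingface_hub",
    )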
app.py CHANGED
@@ -1,6 +1,7 @@
"""
Anthropic-Compatible API Endpoint
Lightweight CPU-based implementation for Hugging Face Spaces
+Full Anthropic API parameter compatibility
"""

import os
@@ -9,7 +10,7 @@ import uuid
import logging
from datetime import datetime
from logging.handlers import RotatingFileHandler
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Dict, Any, Literal
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Header, Request
@@ -26,32 +27,24 @@ LOG_DIR = "/tmp/logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOG_FILE = os.path.join(LOG_DIR, "api.log")

-# Create formatters
log_format = logging.Formatter(
    '%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

-# File handler with rotation (10MB max, keep 5 backups)
file_handler = RotatingFileHandler(
-    LOG_FILE,
-    maxBytes=10*1024*1024,
-    backupCount=5,
-    encoding='utf-8'
+    LOG_FILE, maxBytes=10*1024*1024, backupCount=5, encoding='utf-8'
)
file_handler.setFormatter(log_format)
file_handler.setLevel(logging.DEBUG)

-# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format)
console_handler.setLevel(logging.INFO)

-# Root logger
logging.basicConfig(level=logging.DEBUG, handlers=[file_handler, console_handler])
logger = logging.getLogger("anthropic-api")

-# Also capture uvicorn logs
for uvicorn_logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
    uv_log = logging.getLogger(uvicorn_logger)
    uv_log.handlers = [file_handler, console_handler]
@@ -62,29 +55,21 @@ logger.info(f"Log file: {LOG_FILE}")
logger.info("=" * 60)

# ============== Configuration ==============
-MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"  # Ultra-lightweight 135M model
-MAX_TOKENS_DEFAULT = 1024
+MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"
DEVICE = "cpu"

-# Global model and tokenizer
model = None
tokenizer = None

@asynccontextmanager
async def lifespan(app: FastAPI):
-    """Load model on startup"""
    global model, tokenizer
    logger.info(f"Loading model: {MODEL_ID}")
-
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        logger.info("Tokenizer loaded successfully")
-
        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            torch_dtype=torch.float32,
-            device_map=DEVICE,
-            low_cpu_mem_usage=True
+            MODEL_ID, torch_dtype=torch.float32, device_map=DEVICE, low_cpu_mem_usage=True
        )
        model.eval()
        logger.info("Model loaded successfully!")
@@ -92,21 +77,17 @@ async def lifespan(app: FastAPI):
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise
-
    yield
-
-    # Cleanup
    logger.info("Shutting down, cleaning up model...")
    del model, tokenizer

app = FastAPI(
    title="Anthropic-Compatible API",
-    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
+    description="Lightweight CPU-based API with full Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan
)

-# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
@@ -115,14 +96,11 @@ app.add_middleware(
    allow_headers=["*"],
)

-# Request logging middleware
@app.middleware("http")
async def log_requests(request: Request, call_next):
    request_id = str(uuid.uuid4())[:8]
    start_time = time.time()
-
    logger.info(f"[{request_id}] {request.method} {request.url.path} - Started")
-
    try:
        response = await call_next(request)
        duration = (time.time() - start_time) * 1000
@@ -133,69 +111,193 @@ async def log_requests(request: Request, call_next):
        logger.error(f"[{request_id}] {request.method} {request.url.path} - Error: {e} ({duration:.2f}ms)")
        raise

-# ============== Pydantic Models (Anthropic-Compatible) ==============
+# ============== Anthropic-Compatible Pydantic Models ==============

-class ContentBlock(BaseModel):
-    type: str = "text"
+# Content block types (matching Anthropic exactly)
+class TextBlock(BaseModel):
+    type: Literal["text"] = "text"
    text: str

+class ImageSource(BaseModel):
+    type: Literal["base64", "url"] = "base64"
+    media_type: Optional[str] = None
+    data: Optional[str] = None
+    url: Optional[str] = None
+
+class ImageBlock(BaseModel):
+    type: Literal["image"] = "image"
+    source: ImageSource
+
+class ToolUseBlock(BaseModel):
+    type: Literal["tool_use"] = "tool_use"
+    id: str
+    name: str
+    input: Dict[str, Any]
+
+class ToolResultBlock(BaseModel):
+    type: Literal["tool_result"] = "tool_result"
+    tool_use_id: str
+    content: Optional[Union[str, List[TextBlock]]] = None
+    is_error: Optional[bool] = False
+
+ContentBlock = Union[TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock]
+
+# Message structure (matching Anthropic exactly)
class Message(BaseModel):
-    role: str
+    role: Literal["user", "assistant"]
    content: Union[str, List[ContentBlock]]

+# Tool definition (matching Anthropic exactly)
+class ToolInputSchema(BaseModel):
+    type: Literal["object"] = "object"
+    properties: Optional[Dict[str, Any]] = None
+    required: Optional[List[str]] = None
+
+class Tool(BaseModel):
+    name: str
+    description: Optional[str] = None
+    input_schema: ToolInputSchema
+
+# Tool choice (matching Anthropic exactly)
+class ToolChoiceAuto(BaseModel):
+    type: Literal["auto"] = "auto"
+    disable_parallel_tool_use: Optional[bool] = None
+
+class ToolChoiceAny(BaseModel):
+    type: Literal["any"] = "any"
+    disable_parallel_tool_use: Optional[bool] = None
+
+class ToolChoiceTool(BaseModel):
+    type: Literal["tool"] = "tool"
+    name: str
+    disable_parallel_tool_use: Optional[bool] = None
+
+ToolChoice = Union[ToolChoiceAuto, ToolChoiceAny, ToolChoiceTool]
+
+# Metadata (matching Anthropic exactly)
+class Metadata(BaseModel):
+    user_id: Optional[str] = None
+
+# System content (matching Anthropic exactly)
+class SystemContent(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+    cache_control: Optional[Dict[str, str]] = None
+
+# Main request model (matching Anthropic exactly)
class MessageRequest(BaseModel):
+    # Required parameters
    model: str
+    max_tokens: int
    messages: List[Message]
-    max_tokens: int = MAX_TOKENS_DEFAULT
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 0.9
-    top_k: Optional[int] = 50
-    stream: Optional[bool] = False
-    system: Optional[str] = None
-    stop_sequences: Optional[List[str]] = None

+    # Optional parameters (matching Anthropic exactly)
+    metadata: Optional[Metadata] = None
+    stop_sequences: Optional[List[str]] = None
+    stream: Optional[bool] = False
+    system: Optional[Union[str, List[SystemContent]]] = None
+    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
+    tool_choice: Optional[ToolChoice] = None
+    tools: Optional[List[Tool]] = None
+    top_k: Optional[int] = Field(default=None, ge=0)
+    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+
+# Usage model (matching Anthropic exactly)
class Usage(BaseModel):
    input_tokens: int
    output_tokens: int
+    cache_creation_input_tokens: Optional[int] = None
+    cache_read_input_tokens: Optional[int] = None

+# Response content block
+class ResponseTextBlock(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+class ResponseToolUseBlock(BaseModel):
+    type: Literal["tool_use"] = "tool_use"
+    id: str
+    name: str
+    input: Dict[str, Any]
+
+ResponseContentBlock = Union[ResponseTextBlock, ResponseToolUseBlock]
+
+# Main response model (matching Anthropic exactly)
class MessageResponse(BaseModel):
    id: str
-    type: str = "message"
-    role: str = "assistant"
-    content: List[ContentBlock]
+    type: Literal["message"] = "message"
+    role: Literal["assistant"] = "assistant"
+    content: List[ResponseContentBlock]
    model: str
-    stop_reason: str = "end_turn"
+    stop_reason: Optional[Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]] = None
    stop_sequence: Optional[str] = None
    usage: Usage

+# Error response (matching Anthropic exactly)
+class ErrorDetail(BaseModel):
+    type: str
+    message: str
+
class ErrorResponse(BaseModel):
-    type: str = "error"
-    error: dict
+    type: Literal["error"] = "error"
+    error: ErrorDetail
+
+# Token count request/response (matching Anthropic exactly)
+class TokenCountRequest(BaseModel):
+    model: str
+    messages: List[Message]
+    system: Optional[Union[str, List[SystemContent]]] = None
+    tools: Optional[List[Tool]] = None
+
+class TokenCountResponse(BaseModel):
+    input_tokens: int

# ============== Helper Functions ==============

-def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
+def extract_text_content(content: Union[str, List[ContentBlock]]) -> str:
+    """Extract text from content (string or list of blocks)"""
+    if isinstance(content, str):
+        return content
+    texts = []
+    for block in content:
+        if isinstance(block, dict):
+            if block.get("type") == "text":
+                texts.append(block.get("text", ""))
+        elif hasattr(block, "type") and block.type == "text":
+            texts.append(block.text)
+    return " ".join(texts)
+
+def extract_system_content(system: Optional[Union[str, List[SystemContent]]]) -> Optional[str]:
+    """Extract system prompt from string or list of system content blocks"""
+    if system is None:
+        return None
+    if isinstance(system, str):
+        return system
+    texts = []
+    for block in system:
+        if isinstance(block, dict):
+            texts.append(block.get("text", ""))
+        elif hasattr(block, "text"):
+            texts.append(block.text)
+    return " ".join(texts)
+
+def format_messages(messages: List[Message], system: Optional[Union[str, List[SystemContent]]] = None) -> str:
    """Format messages into a prompt string"""
    formatted_messages = []

-    if system:
-        formatted_messages.append({"role": "system", "content": system})
+    system_text = extract_system_content(system)
+    if system_text:
+        formatted_messages.append({"role": "system", "content": system_text})

    for msg in messages:
-        content = msg.content
-        if isinstance(content, list):
-            content = " ".join([block.text for block in content if block.type == "text"])
+        content = extract_text_content(msg.content)
        formatted_messages.append({"role": msg.role, "content": content})

-    # Use chat template if available
    if tokenizer.chat_template:
        return tokenizer.apply_chat_template(
-            formatted_messages,
-            tokenize=False,
-            add_generation_prompt=True
+            formatted_messages, tokenize=False, add_generation_prompt=True
        )

-    # Fallback simple format
    prompt = ""
    for msg in formatted_messages:
        role = msg["role"].capitalize()
@@ -204,14 +306,12 @@ def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
    return prompt

def generate_id() -> str:
-    """Generate a unique message ID"""
    return f"msg_{uuid.uuid4().hex[:24]}"

# ============== API Endpoints ==============

@app.get("/")
async def root():
-    """Health check endpoint"""
    logger.debug("Root endpoint accessed")
    return {
        "status": "healthy",
@@ -223,24 +323,20 @@ async def root():

@app.get("/v1/models")
async def list_models():
-    """List available models (Anthropic-compatible)"""
    logger.debug("Models list requested")
    return {
        "object": "list",
-        "data": [
-            {
-                "id": "smollm2-135m",
-                "object": "model",
-                "created": int(time.time()),
-                "owned_by": "huggingface",
-                "display_name": "SmolLM2 135M Instruct"
-            }
-        ]
+        "data": [{
+            "id": "smollm2-135m",
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "huggingface",
+            "display_name": "SmolLM2 135M Instruct"
+        }]
    }

@app.get("/logs")
async def get_logs(lines: int = 100):
-    """Get recent log entries"""
    try:
        with open(LOG_FILE, 'r') as f:
            all_lines = f.readlines()
@@ -254,24 +350,22 @@ async def get_logs(lines: int = 100):
    except FileNotFoundError:
        return {"error": "Log file not found", "log_file": LOG_FILE}

-@app.post("/v1/messages")
+@app.post("/v1/messages", response_model=MessageResponse)
async def create_message(
    request: MessageRequest,
    x_api_key: Optional[str] = Header(None, alias="x-api-key"),
-    anthropic_version: Optional[str] = Header(None, alias="anthropic-version")
+    anthropic_version: Optional[str] = Header(None, alias="anthropic-version"),
+    anthropic_beta: Optional[str] = Header(None, alias="anthropic-beta")
):
-    """
-    Create a message (Anthropic Messages API compatible)
-    """
+    """Create a message (Anthropic Messages API compatible)"""
    message_id = generate_id()
    logger.info(f"[{message_id}] Creating message - model: {request.model}, max_tokens: {request.max_tokens}, stream: {request.stream}")
+    logger.debug(f"[{message_id}] Request params - temp: {request.temperature}, top_p: {request.top_p}, top_k: {request.top_k}")

    try:
-        # Format the prompt
        prompt = format_messages(request.messages, request.system)
        logger.debug(f"[{message_id}] Prompt length: {len(prompt)} chars")

-        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_token_count = inputs.input_ids.shape[1]
        logger.info(f"[{message_id}] Input tokens: {input_token_count}")
@@ -280,41 +374,72 @@ async def create_message(
            logger.info(f"[{message_id}] Starting streaming response")
            return await stream_response(request, inputs, input_token_count, message_id)

-        # Generate
+        # Build generation kwargs matching Anthropic params
+        gen_kwargs = {
+            "max_new_tokens": request.max_tokens,
+            "do_sample": request.temperature > 0 if request.temperature else False,
+            "pad_token_id": tokenizer.eos_token_id,
+            "eos_token_id": tokenizer.eos_token_id,
+        }
+
+        # Temperature (Anthropic default: 1.0)
+        if request.temperature is not None and request.temperature > 0:
+            gen_kwargs["temperature"] = request.temperature
+
+        # Top-p (nucleus sampling)
+        if request.top_p is not None:
+            gen_kwargs["top_p"] = request.top_p
+
+        # Top-k sampling
+        if request.top_k is not None:
+            gen_kwargs["top_k"] = request.top_k
+
+        # Stop sequences
+        if request.stop_sequences:
+            stop_token_ids = []
+            for seq in request.stop_sequences:
+                tokens = tokenizer.encode(seq, add_special_tokens=False)
+                if tokens:
+                    stop_token_ids.extend(tokens)
+            if stop_token_ids:
+                gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + stop_token_ids))
+
        gen_start = time.time()
        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=request.max_tokens,
-                temperature=request.temperature if request.temperature > 0 else 1.0,
-                top_p=request.top_p,
-                top_k=request.top_k,
-                do_sample=request.temperature > 0,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-            )
+            outputs = model.generate(**inputs, **gen_kwargs)
        gen_time = time.time() - gen_start

-        # Decode only new tokens
        generated_tokens = outputs[0][input_token_count:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        output_token_count = len(generated_tokens)

+        # Determine stop reason
+        stop_reason = "end_turn"
+        stop_sequence = None
+        if output_token_count >= request.max_tokens:
+            stop_reason = "max_tokens"
+        elif request.stop_sequences:
+            for seq in request.stop_sequences:
+                if seq in generated_text:
+                    stop_reason = "stop_sequence"
+                    stop_sequence = seq
+                    generated_text = generated_text.split(seq)[0]
+                    break
+
        tokens_per_sec = output_token_count / gen_time if gen_time > 0 else 0
        logger.info(f"[{message_id}] Generated {output_token_count} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")

-        # Build response
        response = MessageResponse(
            id=message_id,
-            content=[ContentBlock(type="text", text=generated_text.strip())],
+            content=[ResponseTextBlock(type="text", text=generated_text.strip())],
            model=request.model,
-            stop_reason="end_turn",
+            stop_reason=stop_reason,
+            stop_sequence=stop_sequence,
            usage=Usage(
                input_tokens=input_token_count,
                output_tokens=output_token_count
            )
        )
-
        return response

    except Exception as e:
@@ -322,10 +447,10 @@ async def create_message(
        raise HTTPException(status_code=500, detail=str(e))

async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
-    """Stream response using SSE (Server-Sent Events)"""
+    """Stream response using SSE (Server-Sent Events) - Anthropic format"""

    async def generate():
-        # Send message_start event
+        # message_start event
        start_event = {
            "type": "message_start",
            "message": {
@@ -341,7 +466,7 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
        }
        yield f"event: message_start\ndata: {json.dumps(start_event)}\n\n"

-        # Send content_block_start
+        # content_block_start event
        block_start = {
            "type": "content_block_start",
            "index": 0,
@@ -349,24 +474,29 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
        }
        yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n"

-        # Setup streamer
+        # ping event (Anthropic sends these)
+        yield f"event: ping\ndata: {json.dumps({'type': 'ping'})}\n\n"
+
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

-        generation_kwargs = {
+        gen_kwargs = {
            **inputs,
            "max_new_tokens": request.max_tokens,
-            "temperature": request.temperature if request.temperature > 0 else 1.0,
-            "top_p": request.top_p,
-            "top_k": request.top_k,
-            "do_sample": request.temperature > 0,
+            "do_sample": request.temperature > 0 if request.temperature else False,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }

-        # Run generation in a thread
+        if request.temperature is not None and request.temperature > 0:
+            gen_kwargs["temperature"] = request.temperature
+        if request.top_p is not None:
+            gen_kwargs["top_p"] = request.top_p
+        if request.top_k is not None:
+            gen_kwargs["top_k"] = request.top_k
+
        gen_start = time.time()
-        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        output_tokens = 0
@@ -385,19 +515,19 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
        tokens_per_sec = output_tokens / gen_time if gen_time > 0 else 0
        logger.info(f"[{message_id}] Stream completed: {output_tokens} tokens in {gen_time:.2f}s ({tokens_per_sec:.1f} tok/s)")

-        # Send content_block_stop
-        block_stop = {"type": "content_block_stop", "index": 0}
-        yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n"
+        # content_block_stop event
+        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"

-        # Send message_delta
+        # message_delta event
+        stop_reason = "max_tokens" if output_tokens >= request.max_tokens else "end_turn"
        delta = {
            "type": "message_delta",
-            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
+            "delta": {"stop_reason": stop_reason, "stop_sequence": None},
            "usage": {"output_tokens": output_tokens}
        }
        yield f"event: message_delta\ndata: {json.dumps(delta)}\n\n"

-        # Send message_stop
+        # message_stop event
        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"

    return StreamingResponse(
@@ -410,16 +540,14 @@ async def stream_response(request: MessageRequest, inputs, input_token_count: int, message_id: str):
        }
    )

-# Token counting endpoint
-@app.post("/v1/messages/count_tokens")
-async def count_tokens(request: MessageRequest):
-    """Count tokens for a message request"""
+@app.post("/v1/messages/count_tokens", response_model=TokenCountResponse)
+async def count_tokens(request: TokenCountRequest):
+    """Count tokens for a message request (Anthropic compatible)"""
    prompt = format_messages(request.messages, request.system)
    tokens = tokenizer.encode(prompt)
    logger.debug(f"Token count request: {len(tokens)} tokens")
-    return {"input_tokens": len(tokens)}
+    return TokenCountResponse(input_tokens=len(tokens))

-# Health check
@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": model is not None, "log_file": LOG_FILE}