"""
Anthropic-Compatible API Endpoint
Lightweight CPU-based implementation for Hugging Face Spaces
"""

import os
import time
import uuid
from typing import List, Optional, Union
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import json

# ============== Configuration ==============
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"  # Ultra-lightweight 135M model
MAX_TOKENS_DEFAULT = 1024
TEMPERATURE_DEFAULT = 0.7
DEVICE = "cpu"

# Global model and tokenizer, populated by ``lifespan`` at startup.
model = None
tokenizer = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer on startup; release them on shutdown."""
    global model, tokenizer
    print(f"Loading model: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map=DEVICE,
        low_cpu_mem_usage=True,
    )
    model.eval()
    print("Model loaded successfully!")
    yield
    # Cleanup: rebind to None instead of ``del`` so module-level readers such as
    # ``health`` still find a defined name (``del`` on a global removes it).
    model = None
    tokenizer = None


app = FastAPI(
    title="Anthropic-Compatible API",
    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============== Pydantic Models (Anthropic-Compatible) ==============

class ContentBlock(BaseModel):
    type: str = "text"
    text: str


class Message(BaseModel):
    role: str
    content: Union[str, List[ContentBlock]]


class MessageRequest(BaseModel):
    model: str
    messages: List[Message]
    # Must be >= 1: reject 0/negative values with a 422 rather than letting
    # model.generate() fail later with a 500.
    max_tokens: int = Field(MAX_TOKENS_DEFAULT, ge=1)
    # None is treated as TEMPERATURE_DEFAULT; values <= 0 select greedy decoding.
    temperature: Optional[float] = TEMPERATURE_DEFAULT
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    stream: Optional[bool] = False
    system: Optional[str] = None
    # NOTE: accepted for API compatibility but not applied during generation.
    stop_sequences: Optional[List[str]] = None


class Usage(BaseModel):
    input_tokens: int
    output_tokens: int


class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: Usage


class ErrorResponse(BaseModel):
    type: str = "error"
    error: dict


# ============== Helper Functions ==============

def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
    """Format messages into a single prompt string.

    Uses the tokenizer's chat template when one is configured; otherwise
    falls back to a plain ``Role: content`` transcript ending with
    ``Assistant: ``. List-form content keeps only ``text`` blocks, joined
    with single spaces.
    """
    formatted_messages = []

    if system:
        formatted_messages.append({"role": "system", "content": system})

    for msg in messages:
        content = msg.content
        if isinstance(content, list):
            content = " ".join(block.text for block in content if block.type == "text")
        formatted_messages.append({"role": msg.role, "content": content})

    # Use chat template if available
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    # Fallback simple format
    prompt = ""
    for msg in formatted_messages:
        role = msg["role"].capitalize()
        prompt += f"{role}: {msg['content']}\n"
    prompt += "Assistant: "
    return prompt


def generate_id() -> str:
    """Generate a unique message ID"""
    return f"msg_{uuid.uuid4().hex[:24]}"


def _sampling_kwargs(request: MessageRequest) -> dict:
    """Build the ``model.generate`` keyword arguments for *request*.

    Shared by the streaming and non-streaming paths so both decode
    identically. A ``temperature`` of ``None`` falls back to
    TEMPERATURE_DEFAULT; a value <= 0 selects greedy decoding.
    """
    temperature = request.temperature
    if temperature is None:
        temperature = TEMPERATURE_DEFAULT
    do_sample = temperature > 0
    return {
        "max_new_tokens": request.max_tokens,
        # Greedy decoding ignores temperature; pass a neutral 1.0 instead of <= 0.
        "temperature": temperature if do_sample else 1.0,
        "top_p": request.top_p,
        "top_k": request.top_k,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }


def _stop_reason(generated_tokens, max_tokens: int) -> str:
    """Return ``"max_tokens"`` if generation was cut off by the token budget,
    otherwise ``"end_turn"``.

    *generated_tokens* is the 1-D slice of new token ids, excluding the prompt.
    """
    hit_eos = (
        len(generated_tokens) > 0
        and int(generated_tokens[-1]) == tokenizer.eos_token_id
    )
    if not hit_eos and len(generated_tokens) >= max_tokens:
        return "max_tokens"
    return "end_turn"


# ============== API Endpoints ==============

@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "api_version": "2023-06-01",
        "compatibility": "anthropic-messages-api",
    }


@app.get("/v1/models")
async def list_models():
    """List available models (Anthropic-compatible)"""
    return {
        "object": "list",
        "data": [
            {
                "id": "smollm2-135m",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "huggingface",
                "display_name": "SmolLM2 135M Instruct",
            }
        ],
    }


@app.post("/v1/messages")
async def create_message(
    request: MessageRequest,
    x_api_key: Optional[str] = Header(None, alias="x-api-key"),
    anthropic_version: Optional[str] = Header(None, alias="anthropic-version"),
):
    """
    Create a message (Anthropic Messages API compatible)

    NOTE: ``x_api_key`` and ``anthropic_version`` are accepted but never
    checked; this endpoint performs no authentication.

    Raises:
        HTTPException: status 500 wrapping any failure during prompt
            formatting, tokenization or generation.
    """
    try:
        # Format the prompt
        prompt = format_messages(request.messages, request.system)

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_token_count = inputs.input_ids.shape[1]

        if request.stream:
            return await stream_response(request, inputs, input_token_count)

        def _generate():
            # no_grad is entered inside the worker thread because PyTorch's
            # grad mode is thread-local.
            with torch.no_grad():
                return model.generate(**inputs, **_sampling_kwargs(request))

        # generate() is synchronous and CPU-bound; run it off the event loop so
        # other requests keep being served while this one decodes.
        outputs = await run_in_threadpool(_generate)

        # Decode only new tokens
        generated_tokens = outputs[0][input_token_count:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        output_token_count = len(generated_tokens)

        # Build response
        return MessageResponse(
            id=generate_id(),
            content=[ContentBlock(type="text", text=generated_text.strip())],
            model=request.model,
            stop_reason=_stop_reason(generated_tokens, request.max_tokens),
            usage=Usage(
                input_tokens=input_token_count,
                output_tokens=output_token_count,
            ),
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


async def stream_response(request: MessageRequest, inputs, input_token_count: int):
    """Stream response using SSE (Server-Sent Events).

    Emits message_start, content_block_start, one content_block_delta per
    decoded text chunk, content_block_stop, message_delta (with stop_reason
    and the exact output token count) and message_stop. If generation fails,
    an ``error`` event is emitted instead of the closing events.
    """
    message_id = generate_id()

    def sse(event: str, payload: dict) -> str:
        """Serialize one SSE frame."""
        return f"event: {event}\ndata: {json.dumps(payload)}\n\n"

    def generate():
        # A plain (sync) generator: StreamingResponse iterates sync iterators in
        # a threadpool, so blocking on the streamer queue below does not stall
        # the event loop.
        yield sse("message_start", {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_token_count, "output_tokens": 0},
            },
        })

        yield sse("content_block_start", {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""},
        })

        # Setup streamer
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = {**inputs, **_sampling_kwargs(request), "streamer": streamer}
        result = {}

        def run_generation():
            try:
                # Keep the returned ids: they give the exact output token count
                # and let us tell an EOS stop from a max_tokens cut-off.
                result["output"] = model.generate(**generation_kwargs)
            except Exception as exc:
                result["error"] = exc
                # On failure the streamer may never receive its end signal;
                # end it so the consuming loop below cannot block forever.
                streamer.end()

        # Run generation in a thread
        thread = Thread(target=run_generation)
        thread.start()

        for text in streamer:
            if text:
                yield sse("content_block_delta", {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text},
                })

        thread.join()

        if "error" in result:
            yield sse("error", {
                "type": "error",
                "error": {"type": "api_error", "message": str(result["error"])},
            })
            return

        generated_tokens = result["output"][0][input_token_count:]

        yield sse("content_block_stop", {"type": "content_block_stop", "index": 0})

        yield sse("message_delta", {
            "type": "message_delta",
            "delta": {
                "stop_reason": _stop_reason(generated_tokens, request.max_tokens),
                "stop_sequence": None,
            },
            "usage": {"output_tokens": len(generated_tokens)},
        })

        yield sse("message_stop", {"type": "message_stop"})

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# Token counting endpoint
@app.post("/v1/messages/count_tokens")
async def count_tokens(request: MessageRequest):
    """Count prompt tokens for a message request (same formatting as /v1/messages)."""
    prompt = format_messages(request.messages, request.system)
    tokens = tokenizer.encode(prompt)
    return {"input_tokens": len(tokens)}


# Health check
@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": model is not None}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)