|
|
""" |
|
|
Anthropic-Compatible API Endpoint |
|
|
Lightweight CPU-based implementation for Hugging Face Spaces |
|
|
""" |
|
|
|
|
|
import os |
|
|
import time |
|
|
import uuid |
|
|
from typing import List, Optional, Union |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Header, Request |
|
|
from fastapi.responses import StreamingResponse, JSONResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
|
from threading import Thread |
|
|
import json |
|
|
|
|
|
|
|
|
# Hugging Face model identifier: a small instruct-tuned model that fits on
# CPU-only Hugging Face Spaces hardware.
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"

# Default cap on newly generated tokens when the client omits max_tokens.
MAX_TOKENS_DEFAULT = 1024

# Inference device; this deployment deliberately targets CPU only.
DEVICE = "cpu"

# Module-level model/tokenizer handles; populated by the lifespan hook at
# startup and read by the request handlers below.
model = None
tokenizer = None
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load model/tokenizer at startup, release at shutdown."""
    global model, tokenizer

    print(f"Loading model: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map=DEVICE,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    model.eval()
    print("Model loaded successfully!")

    # Hand control to the application while it serves requests.
    yield

    # Shutdown: drop the global references so the weights can be collected.
    del model, tokenizer
|
|
|
|
|
# FastAPI application; the lifespan hook above loads/unloads the model.
app = FastAPI(
    title="Anthropic-Compatible API",
    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan
)

# Allow browser clients from any origin to call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
class ContentBlock(BaseModel):
    # One text block in Anthropic's content-block format.
    type: str = "text"  # only "text" blocks are produced by this server
    text: str
|
|
|
|
|
class Message(BaseModel):
    # One chat turn from the request body.
    role: str  # e.g. "user" / "assistant"; not validated here
    content: Union[str, List[ContentBlock]]  # plain text or a list of blocks
|
|
|
|
|
class MessageRequest(BaseModel):
    # Request body for POST /v1/messages, mirroring Anthropic's Messages API.
    model: str  # echoed back in the response, not used to select weights
    messages: List[Message]
    max_tokens: int = MAX_TOKENS_DEFAULT  # cap on newly generated tokens
    temperature: Optional[float] = 0.7  # 0 means greedy decoding (do_sample=False)
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    stream: Optional[bool] = False  # SSE streaming when true
    system: Optional[str] = None  # optional system prompt prepended to the chat
    stop_sequences: Optional[List[str]] = None  # client-supplied stop strings
|
|
|
|
|
class Usage(BaseModel):
    # Token accounting reported in responses.
    input_tokens: int   # prompt length in tokens
    output_tokens: int  # generated length in tokens
|
|
|
|
|
class MessageResponse(BaseModel):
    # Response body for POST /v1/messages (Anthropic "message" object shape).
    id: str  # "msg_..." identifier from generate_id()
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str  # echoes the requested model name
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None  # which stop sequence fired, if any
    usage: Usage
|
|
|
|
|
class ErrorResponse(BaseModel):
    # Anthropic-style error envelope.
    # NOTE(review): not referenced by any handler visible in this file —
    # presumably kept for API-shape parity; confirm before removing.
    type: str = "error"
    error: dict
|
|
|
|
|
|
|
|
|
|
|
def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
    """Build the model prompt string from a list of chat messages.

    Flattens block-style content to plain text, then renders via the
    tokenizer's chat template when one is available, falling back to a
    simple "Role: content" transcript otherwise.
    """
    chat = [{"role": "system", "content": system}] if system else []

    for message in messages:
        text = message.content
        if isinstance(text, list):
            # Collapse a list of content blocks into one space-joined string.
            text = " ".join(block.text for block in text if block.type == "text")
        chat.append({"role": message.role, "content": text})

    if tokenizer.chat_template:
        # Preferred path: let the tokenizer render its native chat format.
        return tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

    # Fallback: plain "Role: content" transcript ending with an assistant cue.
    transcript = [f"{entry['role'].capitalize()}: {entry['content']}\n" for entry in chat]
    return "".join(transcript) + "Assistant: "
|
|
|
|
|
def generate_id() -> str:
    """Return a fresh Anthropic-style message ID: ``msg_`` + 24 hex chars."""
    token = uuid.uuid4().hex
    return "msg_" + token[:24]
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
async def root(): |
|
|
"""Health check endpoint""" |
|
|
return { |
|
|
"status": "healthy", |
|
|
"model": MODEL_ID, |
|
|
"api_version": "2023-06-01", |
|
|
"compatibility": "anthropic-messages-api" |
|
|
} |
|
|
|
|
|
@app.get("/v1/models") |
|
|
async def list_models(): |
|
|
"""List available models (Anthropic-compatible)""" |
|
|
return { |
|
|
"object": "list", |
|
|
"data": [ |
|
|
{ |
|
|
"id": "smollm2-135m", |
|
|
"object": "model", |
|
|
"created": int(time.time()), |
|
|
"owned_by": "huggingface", |
|
|
"display_name": "SmolLM2 135M Instruct" |
|
|
} |
|
|
] |
|
|
} |
|
|
|
|
|
@app.post("/v1/messages") |
|
|
async def create_message( |
|
|
request: MessageRequest, |
|
|
x_api_key: Optional[str] = Header(None, alias="x-api-key"), |
|
|
anthropic_version: Optional[str] = Header(None, alias="anthropic-version") |
|
|
): |
|
|
""" |
|
|
Create a message (Anthropic Messages API compatible) |
|
|
""" |
|
|
try: |
|
|
|
|
|
prompt = format_messages(request.messages, request.system) |
|
|
|
|
|
|
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE) |
|
|
input_token_count = inputs.input_ids.shape[1] |
|
|
|
|
|
if request.stream: |
|
|
return await stream_response(request, inputs, input_token_count) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model.generate( |
|
|
**inputs, |
|
|
max_new_tokens=request.max_tokens, |
|
|
temperature=request.temperature if request.temperature > 0 else 1.0, |
|
|
top_p=request.top_p, |
|
|
top_k=request.top_k, |
|
|
do_sample=request.temperature > 0, |
|
|
pad_token_id=tokenizer.eos_token_id, |
|
|
eos_token_id=tokenizer.eos_token_id, |
|
|
) |
|
|
|
|
|
|
|
|
generated_tokens = outputs[0][input_token_count:] |
|
|
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) |
|
|
output_token_count = len(generated_tokens) |
|
|
|
|
|
|
|
|
response = MessageResponse( |
|
|
id=generate_id(), |
|
|
content=[ContentBlock(type="text", text=generated_text.strip())], |
|
|
model=request.model, |
|
|
stop_reason="end_turn", |
|
|
usage=Usage( |
|
|
input_tokens=input_token_count, |
|
|
output_tokens=output_token_count |
|
|
) |
|
|
) |
|
|
|
|
|
return response |
|
|
|
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|
|
async def stream_response(request: MessageRequest, inputs, input_token_count: int):
    """Stream a generation as Anthropic-style SSE events.

    Emits the event sequence message_start -> content_block_start ->
    content_block_delta* -> content_block_stop -> message_delta ->
    message_stop. Generation runs on a background thread; text chunks are
    pulled from a TextIteratorStreamer.
    """

    message_id = generate_id()

    # Fix: temperature/top_p/top_k are Optional and may arrive as JSON null;
    # default them before comparing (previously `None > 0` raised TypeError).
    temperature = 0.7 if request.temperature is None else request.temperature
    top_p = 0.9 if request.top_p is None else request.top_p
    top_k = 50 if request.top_k is None else request.top_k

    async def generate():
        # --- message_start: envelope with zero output tokens so far -------
        start_event = {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_token_count, "output_tokens": 0}
            }
        }
        yield f"event: message_start\ndata: {json.dumps(start_event)}\n\n"

        # --- content_block_start: single text block at index 0 ------------
        block_start = {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""}
        }
        yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n"

        # The streamer yields decoded text chunks as the model produces them.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": request.max_tokens,
            "temperature": temperature if temperature > 0 else 1.0,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }

        # generate() blocks, so it runs on a worker thread while we consume
        # the streamer here.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # NOTE(review): iterating the streamer blocks the event loop between
        # chunks — acceptable for a single-user Space, revisit for concurrent
        # clients. TODO: stop_sequences are not honored on this streaming
        # path — confirm whether clients rely on them here.
        output_tokens = 0
        for text in streamer:
            if text:
                # Approximate count: re-encoding chunks may not sum exactly
                # to the true generated-token count across chunk boundaries.
                output_tokens += len(tokenizer.encode(text, add_special_tokens=False))
                delta_event = {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text}
                }
                yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"

        thread.join()

        # --- closing events ------------------------------------------------
        block_stop = {"type": "content_block_stop", "index": 0}
        yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n"

        delta = {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {"output_tokens": output_tokens}
        }
        yield f"event: message_delta\ndata: {json.dumps(delta)}\n\n"

        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy buffering so events reach the client immediately.
            "X-Accel-Buffering": "no"
        }
    )
|
|
|
|
|
|
|
|
@app.post("/v1/messages/count_tokens") |
|
|
async def count_tokens(request: MessageRequest): |
|
|
"""Count tokens for a message request""" |
|
|
prompt = format_messages(request.messages, request.system) |
|
|
tokens = tokenizer.encode(prompt) |
|
|
return {"input_tokens": len(tokens)} |
|
|
|
|
|
|
|
|
@app.get("/health") |
|
|
async def health(): |
|
|
return {"status": "ok", "model_loaded": model is not None} |
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860) |
|
|
|