#!/usr/bin/env python3
"""
SmolLM3-3B FastAPI Application for LangGraph Conductor
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import os
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="SmolLM3-3B LangGraph Conductor API",
    description="FastAPI backend serving the SmolLM3-3B model for the LangGraph conductor",
    version="1.0.0"
)
# Global model and tokenizer
model = None
tokenizer = None
class ChatMessage(BaseModel):
    role: str
    content: str
class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    enable_thinking: Optional[bool] = False
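# Note: `enable_thinking` is accepted for compatibility with SmolLM3's thinking mode,
# but the /chat handler below does not currently use it.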
class ChatResponse(BaseModel):
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]
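# The request/response shapes loosely mirror the OpenAI chat-completions format
# (choices/usage), so existing LangGraph client parsing can be reused.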
@app.on_event("startup")
async def load_model():
    """Load the SmolLM3-3B model on startup"""
    global model, tokenizer
    try:
        logger.info("Loading SmolLM3-3B model...")
        model_id = "HuggingFaceTB/SmolLM3-3B"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"model_loaded": model is not None,
"device": "cuda" if torch.cuda.is_available() else "cpu"
}
@app.get("/models")
async def list_models():
"""List available models"""
return {
"data": [
{
"id": "SmolLM3-3B",
"object": "model",
"created": 1699000000,
"owned_by": "huggingface"
}
]
}
@app.post("/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest):
    """Chat completion endpoint"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Format messages into a simple "User:/Assistant:" transcript for the model
        conversation = ""
        for msg in request.messages:
            if msg.role == "user":
                conversation += f"User: {msg.content}\n"
            elif msg.role == "assistant":
                conversation += f"Assistant: {msg.content}\n"
        conversation += "Assistant: "
        # Tokenize, move inputs to the model's device, and generate
        inputs = tokenizer(conversation, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (skip the prompt)
        response_text = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        return ChatResponse(
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text.strip()
                    },
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": inputs.input_ids.shape[1],
                "completion_tokens": len(outputs[0]) - inputs.input_ids.shape[1],
                "total_tokens": len(outputs[0])
            }
        )
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
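# Example client call (illustrative only; assumes the server is running locally
# on port 7860 and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "max_tokens": 128,
#           "temperature": 0.7,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])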