"""
SmolLM3-3B FastAPI Application for LangGraph Conductor
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="SmolLM3-3B LangGraph Conductor API",
    description="FastAPI backend for SmolLM3-3B model serving LangGraph conductor",
    version="1.0.0"
)

# Model and tokenizer are populated by the startup event and shared across requests.
model = None
tokenizer = None


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    enable_thinking: Optional[bool] = False
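
# An illustrative /chat request body (field names and defaults mirror ChatRequest
# above; the message content itself is made up):
#
#   {
#       "messages": [{"role": "user", "content": "Summarize the plan"}],
#       "max_tokens": 512,
#       "temperature": 0.7,
#       "enable_thinking": false
#   }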


class ChatResponse(BaseModel):
    choices: List[Dict[str, Any]]
    usage: Dict[str, int]


@app.on_event("startup")
async def load_model():
    """Load the SmolLM3-3B model on startup"""
    global model, tokenizer

    try:
        logger.info("Loading SmolLM3-3B model...")
        model_id = "HuggingFaceTB/SmolLM3-3B"

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Half precision on GPU keeps memory use down; fall back to float32 on CPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )

        logger.info("Model loaded successfully!")

    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
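
# Back-of-envelope sizing (an estimate, not a measured figure): ~3B parameters at
# 2 bytes each in float16 is roughly 6 GB of weights on GPU, and about twice that
# in float32 on CPU, before activations and the KV cache.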


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }


@app.get("/models")
async def list_models():
    """List available models"""
    return {
        "data": [
            {
                "id": "SmolLM3-3B",
                "object": "model",
                "created": 1699000000,
                "owned_by": "huggingface"
            }
        ]
    }


@app.post("/chat", response_model=ChatResponse)
async def chat_completion(request: ChatRequest):
    """Chat completion endpoint"""
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Build a plain-text prompt from the user/assistant turns; other roles
        # (e.g. system) are ignored by this simple format.
        conversation = ""
        for msg in request.messages:
            if msg.role == "user":
                conversation += f"User: {msg.content}\n"
            elif msg.role == "assistant":
                conversation += f"Assistant: {msg.content}\n"

        conversation += "Assistant: "
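
        # Note: request.enable_thinking is accepted but not used by the plain-text
        # prompt above. A hedged alternative (assuming the SmolLM3 tokenizer ships a
        # chat template that understands an `enable_thinking` flag) would be:
        #
        #     prompt = tokenizer.apply_chat_template(
        #         [{"role": m.role, "content": m.content} for m in request.messages],
        #         tokenize=False,
        #         add_generation_prompt=True,
        #         enable_thinking=request.enable_thinking,
        #     )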

        # Tokenize and move tensors onto the model's device (CPU or GPU).
        inputs = tokenizer(conversation, return_tensors="pt").to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (everything after the prompt).
        response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

        return ChatResponse(
            choices=[
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text.strip()
                    },
                    "finish_reason": "stop"
                }
            ],
            usage={
                "prompt_tokens": inputs.input_ids.shape[1],
                "completion_tokens": len(outputs[0]) - inputs.input_ids.shape[1],
                "total_tokens": len(outputs[0])
            }
        )

    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
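
# Alternatively (assuming this file is saved as app.py), the same server can be
# started with the uvicorn CLI:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860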