|
|
import asyncio
import json
import logging
import os
import time
import uuid
import zlib
from contextlib import asynccontextmanager
from datetime import datetime
from typing import List, Dict, Any, Optional, Union

import GPUtil
import psutil
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from ..configs.config import Config, get_balanced_config
from ..architecture.model import create_compact_model, CompactAIModel
|
|
|
|
|
# Process-wide logging setup; INFO keeps startup and per-request timing
# messages visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared inference state. ``model`` is assigned by the lifespan handler at
# startup and is reset to ``None`` when loading fails. ``tokenizer`` starts as
# a placeholder and is rebound later in this module.
model: Optional[CompactAIModel] = None
tokenizer = None
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before the app starts serving; log when it stops.

    Startup reads two environment variables:

    * ``MODEL_SIZE`` -- size name handed to ``create_compact_model``
      (defaults to ``"small"``).
    * ``MODEL_CHECKPOINT`` -- optional path to a state dict that is loaded
      into the freshly created model.

    Any exception raised while building or loading the model is logged with
    its traceback and leaves the module-level ``model`` set to ``None``; the
    application still starts in that case.
    """
    global model

    logger.info("Loading Compact AI Model...")
    try:
        model_size = os.getenv("MODEL_SIZE", "small")
        model = create_compact_model(model_size)

        checkpoint_path = os.getenv("MODEL_CHECKPOINT")
        if checkpoint_path:
            if os.path.exists(checkpoint_path):
                # NOTE(security): torch.load unpickles the file, so only
                # point MODEL_CHECKPOINT at checkpoints from a trusted source.
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
                model.load_state_dict(checkpoint)
                logger.info(f"Loaded model checkpoint from {checkpoint_path}")
            else:
                # A configured-but-missing checkpoint used to be ignored
                # silently, leaving the service running on freshly
                # initialised weights with no hint in the logs.
                logger.warning(
                    f"MODEL_CHECKPOINT={checkpoint_path} does not exist; "
                    "serving the model without loading a checkpoint"
                )

        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()

        logger.info("Model loaded successfully!")
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() drops.
        logger.exception(f"Failed to load model: {e}")
        model = None

    try:
        yield
    finally:
        # Runs even when the application exits through an exception.
        logger.info("Shutting down...")
|
|
|
|
|
|
|
|
app = FastAPI(
    title="Compact AI Model API",
    description="API for the compact AI model with interleaved thinking",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins allowed to call the API. CORS_ALLOW_ORIGINS is an optional
# comma-separated list; when it is unset (or empty) every origin is allowed,
# which matches the previous hard-coded behaviour.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed cross-origin requests are only enabled for an explicit
    # origin list. Pairing credentials with a wildcard origin would let any
    # website issue authenticated requests on a user's behalf.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """One conversation turn: the speaker's role and the text of the turn."""

    # Both fields are required (Ellipsis default).
    role: str = Field(
        default=...,
        description="Role of the message (user/assistant/system)",
    )
    content: str = Field(
        default=...,
        description="Content of the message",
    )
|
|
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for chat completions.

    Only ``messages`` is required; every other field carries a default.
    """

    model: str = Field("compact-ai-v1", description="Model name")
    messages: List[ChatMessage] = Field(default=..., description="List of messages")

    # Sampling controls.
    max_tokens: Optional[int] = Field(100, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(
        0.7,
        ge=0.0,
        le=2.0,
        description="Sampling temperature",
    )
    top_p: Optional[float] = Field(
        1.0,
        ge=0.0,
        le=1.0,
        description="Top-p sampling",
    )

    # Reasoning controls: depth is either a preset name or an integer.
    reasoning_depth: Optional[Union[str, int]] = Field(
        "adaptive",
        description="Reasoning depth",
    )
    early_stop_threshold: Optional[float] = Field(
        0.85,
        description="Early stop threshold",
    )
    thinking_visualization: Optional[bool] = Field(
        False,
        description="Include thinking visualization",
    )
|
|
|
|
|
|
|
|
class CompletionRequest(BaseModel):
    """Request body for plain text completions; only ``prompt`` is required."""

    model: str = Field("compact-ai-v1", description="Model name")
    prompt: str = Field(default=..., description="Input prompt")
    max_tokens: Optional[int] = Field(50, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(
        0.8,
        ge=0.0,
        le=2.0,
        description="Sampling temperature",
    )
    reasoning_tokens: Optional[int] = Field(
        100,
        description="Maximum reasoning tokens",
    )
|
|
|
|
|
|
|
|
class AnthropicMessageRequest(BaseModel):
    """Request body for the messages-style endpoint.

    ``messages`` is required. ``system`` and ``thinking_config`` are optional
    and default to ``None``.
    """

    model: str = Field("compact-ai-v1", description="Model name")
    messages: List[ChatMessage] = Field(default=..., description="List of messages")
    max_tokens: int = Field(1024, description="Maximum tokens to generate")
    system: Optional[str] = Field(None, description="System message")
    thinking_config: Optional[Dict[str, Any]] = Field(
        None,
        description="Thinking configuration",
    )
|
|
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """One generated alternative inside a chat completion response."""

    index: int  # position of this choice in the response's ``choices`` list
    message: ChatMessage  # the generated turn
    finish_reason: str
    # Optional free-form details about the reasoning pass; omitted by default.
    thinking_trace: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Response envelope for chat completions."""

    id: str
    object: str = "chat.completion"  # fixed type tag for this envelope
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: Dict[str, int]  # integer counters keyed by name
|
|
|
|
|
|
|
|
class CompletionChoice(BaseModel):
    """One generated alternative inside a text completion response."""

    text: str  # the generated text
    index: int
    finish_reason: str
    thinking_tokens: Optional[int] = None  # omitted unless the caller sets it
|
|
|
|
|
|
|
|
class CompletionResponse(BaseModel):
    """Response envelope for text completions."""

    id: str
    object: str = "text_completion"  # fixed type tag for this envelope
    created: int
    model: str
    choices: List[CompletionChoice]
    usage: Dict[str, int]  # integer counters keyed by name
|
|
|
|
|
|
|
|
class AnthropicMessageResponse(BaseModel):
    """Response envelope for the messages-style endpoint."""

    id: str
    type: str = "message"  # fixed type tag for this envelope
    role: str = "assistant"  # responses are always attributed to the assistant
    content: List[Dict[str, Any]]  # ordered list of content blocks
    model: str
    usage: Dict[str, int]  # integer counters keyed by name
|
|
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Descriptor for a single model entry."""

    id: str
    object: str = "model"  # fixed type tag
    created: int
    owned_by: str = "compact-ai"
|
|
|
|
|
|
|
|
class ModelListResponse(BaseModel):
    """List wrapper around ``ModelInfo`` entries."""

    object: str = "list"  # fixed type tag
    data: List[ModelInfo]
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Payload returned by the health endpoint."""

    status: str
    model_loaded: bool
    gpu_available: bool
    memory_usage: Dict[str, Any]  # free-form memory statistics
    uptime: str  # a duration rendered as text
|
|
|
|
|
|
|
|
|
|
|
class SimpleTokenizer:
    """Whitespace tokenizer that maps each word to a checksum-derived id.

    This is a stand-in rather than a trained vocabulary: distinct words can
    share an id, and ``decode`` cannot recover the original text. Word ids
    start at 100; the pad, eos and bos ids are 0, 1 and 2.
    """

    # Word ids are offset by this amount so they never land on the low ids
    # used for the special tokens.
    _RESERVED_IDS = 100

    def __init__(self, vocab_size=32000):
        """Store the vocabulary size and the special-token ids."""
        self.vocab_size = vocab_size
        self.pad_token_id = 0
        self.eos_token_id = 1
        self.bos_token_id = 2

    def encode(self, text: str, max_length=None, truncation=True, padding=False):
        """Convert ``text`` into a list of integer token ids.

        Args:
            text: Input string; it is split on whitespace.
            max_length: Target length used for truncation and padding. A
                falsy value (``None`` or ``0``) disables both.
            truncation: When true, drop ids beyond ``max_length``.
            padding: When true, right-pad with ``pad_token_id`` up to
                ``max_length``.

        Returns:
            A list of ids, each in ``[100, vocab_size)`` apart from padding.
        """
        bucket_count = self.vocab_size - self._RESERVED_IDS
        # zlib.crc32 is stable across processes. The built-in hash() of a str
        # is salted per interpreter start, so it produced different ids after
        # every restart and in every worker process.
        token_ids = [
            zlib.crc32(word.encode("utf-8")) % bucket_count + self._RESERVED_IDS
            for word in text.split()
        ]

        # Honour the truncation flag, which was previously ignored.
        if truncation and max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]

        if padding and max_length:
            # A negative count yields an empty list, so over-long input is
            # left untouched here.
            token_ids += [self.pad_token_id] * (max_length - len(token_ids))

        return token_ids

    def decode(self, token_ids: List[int]):
        """Render ids as space-separated ``<token_N>`` placeholders."""
        return " ".join([f"<token_{tid}>" for tid in token_ids])


# Module-wide tokenizer instance used by the generation helper.
tokenizer = SimpleTokenizer()
|
|
|
|
|
|
|
|
def generate_text(
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.8,
    reasoning_depth: Union[str, int] = "adaptive",
    early_stop_threshold: float = 0.85,
    use_thinking: bool = True,
) -> Dict[str, Any]:
    """Generate a continuation of ``prompt`` with the globally loaded model.

    Args:
        prompt: Text to continue. It is tokenized and cut to 512 tokens.
        max_tokens: Upper bound on the number of tokens to sample.
        temperature: Softmax temperature; ``0`` (or less) selects the
            highest-scoring token instead of sampling.
        reasoning_depth: Either an integer passed through to the model, or
            one of the preset names ``"adaptive"`` (no limit), ``"simple"``
            (1) and ``"complex"`` (4). Any other string maps to 2.
        early_stop_threshold: Accepted for interface compatibility; this
            function does not use it.
        use_thinking: Forwarded to the model call.

    Returns:
        A dict with ``generated_text``, ``thinking_results`` and
        ``reasoning_tokens`` (both taken from the model pass over the
        prompt), ``input_tokens`` and ``output_tokens``.

    Raises:
        HTTPException: 503 when no model is loaded, 400 when the prompt
            contains no tokens, 500 when the model call or sampling fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    input_ids = tokenizer.encode(prompt, max_length=512, truncation=True)
    if not input_ids:
        # An empty sequence has no last position to read logits from; report
        # it as a client error instead of an opaque 500 from inside the model.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    if isinstance(reasoning_depth, str):
        depth_presets = {"adaptive": None, "simple": 1, "complex": 4}
        max_reasoning_depth = depth_presets.get(reasoning_depth, 2)
    else:
        max_reasoning_depth = reasoning_depth

    try:
        sequence = torch.tensor([input_ids], dtype=torch.long)
        if torch.cuda.is_available():
            sequence = sequence.cuda()

        generated_tokens: List[int] = []

        with torch.no_grad():
            outputs = model(
                sequence,
                use_thinking=use_thinking,
                max_reasoning_depth=max_reasoning_depth,
            )
            thinking_results = outputs["thinking_results"]
            reasoning_tokens = outputs.get("final_tokens", 0)
            # Logits for the last position of the first (only) batch item.
            current_logits = outputs["logits"][0][-1]

            for _ in range(max_tokens):
                if temperature > 0:
                    probs = torch.softmax(current_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, 1).item()
                else:
                    next_token = current_logits.argmax().item()

                generated_tokens.append(next_token)
                if next_token == tokenizer.eos_token_id:
                    break

                if len(generated_tokens) < max_tokens:
                    # Feed the sampled token back so the next step is
                    # conditioned on it. Previously the logits were never
                    # refreshed, so every token came from the same
                    # distribution.
                    # NOTE(review): this re-runs the full sequence each step
                    # and assumes the model accepts inputs longer than the
                    # truncated prompt; confirm against its context limit.
                    appended = torch.tensor(
                        [[next_token]],
                        dtype=sequence.dtype,
                        device=sequence.device,
                    )
                    sequence = torch.cat([sequence, appended], dim=1)
                    outputs = model(
                        sequence,
                        use_thinking=use_thinking,
                        max_reasoning_depth=max_reasoning_depth,
                    )
                    current_logits = outputs["logits"][0][-1]

        generated_text = tokenizer.decode(generated_tokens)

        return {
            "generated_text": generated_text,
            "thinking_results": thinking_results,
            "reasoning_tokens": reasoning_tokens,
            "input_tokens": len(input_ids),
            "output_tokens": len(generated_tokens),
        }

    except HTTPException:
        # Never rewrap an HTTP error as a generic 500.
        raise
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Generation failed: {str(e)}"
        ) from e
|
|
|
|
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    The prompt is the content of the last ``user`` message, prefixed with the
    first ``system`` message when one is present; other turns are not used.
    ``top_p`` is accepted but not applied.

    Raises:
        HTTPException: 400 when the request contains no ``user`` message.
            Errors raised by ``generate_text`` propagate unchanged.
    """
    start_time = time.time()

    user_messages = [msg for msg in request.messages if msg.role == "user"]
    if not user_messages:
        raise HTTPException(status_code=400, detail="No user message found")
    prompt = user_messages[-1].content

    system_messages = [msg for msg in request.messages if msg.role == "system"]
    if system_messages:
        prompt = f"System: {system_messages[0].content}\n\n{prompt}"

    # Fall back to defaults only when a field is actually missing (None).
    # ``value or default`` also replaced legitimate zeros, so a request for
    # temperature=0.0 (greedy decoding) was silently served at 0.7.
    max_tokens = 100 if request.max_tokens is None else request.max_tokens
    temperature = 0.7 if request.temperature is None else request.temperature
    early_stop_threshold = (
        0.85 if request.early_stop_threshold is None else request.early_stop_threshold
    )

    result = generate_text(
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        # An empty string or 0 is still treated as "not specified" here.
        reasoning_depth=request.reasoning_depth or "adaptive",
        early_stop_threshold=early_stop_threshold,
    )

    thinking_trace = None
    if request.thinking_visualization and result["thinking_results"]:
        thinking_trace = {
            "reasoning_paths": len(result["thinking_results"]),
            "reasoning_tokens": result["reasoning_tokens"],
            # NOTE(review): these scores are hard-coded placeholders, not
            # values computed from the model output.
            "confidence_scores": [0.85, 0.78, 0.92],
        }

    response = ChatCompletionResponse(
        # A random suffix keeps ids unique; a second-resolution timestamp
        # gave identical ids to requests arriving within the same second.
        id=f"chatcmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(role="assistant", content=result["generated_text"]),
                finish_reason="stop",
                thinking_trace=thinking_trace,
            )
        ],
        usage={
            "prompt_tokens": result["input_tokens"],
            "completion_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
            "reasoning_tokens": result["reasoning_tokens"],
        },
    )

    logger.info(f"Chat completion took {time.time() - start_time:.2f}s")
    return response
|
|
|
|
|
|
|
|
@app.post("/v1/completions", response_model=CompletionResponse)
async def completions(request: CompletionRequest):
    """OpenAI-compatible text completions endpoint.

    Generation always runs with a reasoning depth of 2 and an early-stop
    threshold of 0.8; the request's ``reasoning_tokens`` field is not read.
    Errors raised by ``generate_text`` propagate unchanged.
    """
    start_time = time.time()

    # Fall back to defaults only when a field is actually missing (None).
    # ``value or default`` also replaced an explicit temperature of 0.0 and
    # an explicit max_tokens of 0.
    max_tokens = 50 if request.max_tokens is None else request.max_tokens
    temperature = 0.8 if request.temperature is None else request.temperature

    result = generate_text(
        prompt=request.prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        reasoning_depth=2,
        early_stop_threshold=0.8,
    )

    response = CompletionResponse(
        # A random suffix keeps ids unique; a second-resolution timestamp
        # gave identical ids to requests arriving within the same second.
        id=f"cmpl-{uuid.uuid4().hex}",
        created=int(time.time()),
        model=request.model,
        choices=[
            CompletionChoice(
                text=result["generated_text"],
                index=0,
                finish_reason="stop",
                thinking_tokens=result["reasoning_tokens"],
            )
        ],
        usage={
            "prompt_tokens": result["input_tokens"],
            "completion_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
        },
    )

    logger.info(f"Completion took {time.time() - start_time:.2f}s")
    return response
|
|
|
|
|
|
|
|
@app.post("/v1/messages", response_model=AnthropicMessageResponse)
async def anthropic_messages(request: AnthropicMessageRequest):
    """Anthropic-compatible messages endpoint.

    The prompt is built from the optional system text followed by every
    ``user`` and ``assistant`` turn, joined by blank lines; turns with any
    other role are skipped. ``thinking_config`` may carry
    ``reasoning_depth`` (default ``"complex"``) and
    ``thinking_visualization`` (default ``True``). Generation always uses a
    temperature of 0.7.

    Raises:
        HTTPException: 400 when the request yields an empty prompt. Errors
            raised by ``generate_text`` propagate unchanged.
    """
    start_time = time.time()

    turns = []
    if request.system:
        turns.append(f"System: {request.system}")
    for msg in request.messages:
        if msg.role == "user":
            turns.append(f"Human: {msg.content}")
        elif msg.role == "assistant":
            turns.append(f"Assistant: {msg.content}")

    prompt = "\n\n".join(turns)
    if not prompt.strip():
        # Reject up front instead of letting an empty prompt fail inside
        # generation and come back as a 500.
        raise HTTPException(
            status_code=400, detail="Messages contain no usable content"
        )

    thinking_config = request.thinking_config or {}
    reasoning_depth = thinking_config.get("reasoning_depth", "complex")
    visualization = thinking_config.get("thinking_visualization", True)

    result = generate_text(
        prompt=prompt,
        max_tokens=request.max_tokens,
        temperature=0.7,
        reasoning_depth=reasoning_depth,
        early_stop_threshold=0.85,
    )

    content = [{"type": "text", "text": result["generated_text"]}]

    if visualization and result["thinking_results"]:
        # The reasoning summary goes in front of the generated text.
        thinking_text = f"\n\nThinking process used {result['reasoning_tokens']} reasoning tokens across {len(result['thinking_results'])} layers."
        content.insert(0, {"type": "text", "text": thinking_text})

    response = AnthropicMessageResponse(
        # A random suffix keeps ids unique; a second-resolution timestamp
        # gave identical ids to requests arriving within the same second.
        id=f"msg_{uuid.uuid4().hex}",
        model=request.model,
        content=content,
        usage={
            "input_tokens": result["input_tokens"],
            "output_tokens": result["output_tokens"],
            "total_tokens": result["input_tokens"] + result["output_tokens"],
        },
    )

    logger.info(f"Anthropic message took {time.time() - start_time:.2f}s")
    return response
|
|
|
|
|
|
|
|
@app.get("/v1/models", response_model=ModelListResponse)
async def list_models():
    """Return the model catalogue, which holds the single served model.

    The entry's ``created`` value is the time of the call.
    """
    served_model = ModelInfo(id="compact-ai-v1", created=int(time.time()))
    return ModelListResponse(data=[served_model])
|
|
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    """Describe one model by id.

    Only ``compact-ai-v1`` is known; any other id raises a 404
    ``HTTPException``. The ``created`` value is the time of the call.
    """
    if model_id == "compact-ai-v1":
        return ModelInfo(id=model_id, created=int(time.time()))

    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint.

    Reports whether the model is loaded, whether CUDA is available, host RAM
    usage and, when GPU statistics can be read, figures for the first GPU.
    ``uptime`` is the time since the host booted, not since this service
    started.
    """
    memory_info = psutil.virtual_memory()
    gpu_info: Dict[str, Any] = {}

    try:
        gpus = GPUtil.getGPUs()
        if gpus:
            gpu = gpus[0]
            gpu_info = {
                "gpu_name": gpu.name,
                "gpu_memory_used": gpu.memoryUsed,
                "gpu_memory_total": gpu.memoryTotal,
                "gpu_memory_free": gpu.memoryFree,
                "gpu_utilization": gpu.load * 100,
            }
    except Exception as exc:
        # GPU statistics are best-effort: the check must still answer on
        # hosts without a usable GPU. The failure is recorded instead of
        # being dropped, and a bare ``except`` is avoided so interpreter-exit
        # signals are not swallowed.
        logger.debug("GPU statistics unavailable: %s", exc)

    return HealthResponse(
        status="healthy" if model is not None else "unhealthy",
        model_loaded=model is not None,
        gpu_available=torch.cuda.is_available(),
        memory_usage={
            "ram_used": memory_info.used,
            "ram_total": memory_info.total,
            "ram_percent": memory_info.percent,
            **gpu_info,
        },
        uptime=str(datetime.now() - datetime.fromtimestamp(psutil.boot_time())),
    )
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return the service name and version as a small JSON banner."""
    banner = {
        "message": "Compact AI Model API",
        "version": "1.0.0",
    }
    return banner
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, publish the model settings
    # through environment variables (read at startup by the lifespan
    # handler), then start uvicorn.
    import argparse

    parser = argparse.ArgumentParser(description="Run Compact AI Model API")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--workers", type=int, default=1, help="Number of workers")
    parser.add_argument(
        "--model-size",
        # Default to a MODEL_SIZE already present in the environment. A
        # hard-coded "small" default overwrote it whenever the flag was
        # omitted.
        default=os.getenv("MODEL_SIZE", "small"),
        choices=["tiny", "small", "medium"],
        help="Model size",
    )
    parser.add_argument("--checkpoint", help="Path to model checkpoint")

    args = parser.parse_args()

    os.environ["MODEL_SIZE"] = args.model_size
    if args.checkpoint:
        os.environ["MODEL_CHECKPOINT"] = args.checkpoint

    uvicorn.run(
        "main:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=False,
        log_level="info",
    )