import os
import json
import logging
from pathlib import Path
from typing import Optional, Dict
from datetime import datetime

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from llama_cpp import Llama

# ============================================================================
# SETUP & CONFIG
# ============================================================================

# Root logging configuration: INFO and above, timestamped, level in brackets.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

app = FastAPI(title="LLM Chat API", version="1.0.0")

# CORS is fully open: any origin, method and header is accepted, and
# credentials are allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Persistent storage: HF mounts the bucket at /data.
# When that mount is absent (local dev), use ~/data instead.
_DATA_DIR = Path("/data")
if not _DATA_DIR.exists():
    _DATA_DIR = Path.home() / "data"

# Downloaded GGUF files live here; the directory is created at import time.
MODEL_CACHE_DIR = _DATA_DIR / "models"
MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
logger.info(f"Model cache dir: {MODEL_CACHE_DIR}")

# ============================================================================
# MODEL REGISTRY
# ============================================================================

# Maps a short model id to the metadata needed to download the GGUF file
# ("repo", "file") and to construct the model ("context_size", "chat_format").
# "name", "description" and "size" are human-readable labels.
MODELS_CONFIG = {
    "qwen-3b": {
        "name": "Qwen 2.5 3B Instruct",
        "repo": "Qwen/Qwen2.5-3B-Instruct-GGUF",
        "file": "qwen2.5-3b-instruct-q4_k_m.gguf",
        "context_size": 32768,
        "chat_format": "chatml",
        "description": "Fast 3B model with 32k context",
        "size": "2.5GB",
    },
    # Uncomment to add more (watch RAM — free tier has 16GB total):
    # "qwen-7b": {
    #     "name": "Qwen 2.5 7B Instruct",
    #     "repo": "Qwen/Qwen2.5-7B-Instruct-GGUF",
    #     "file": "qwen2.5-7b-instruct-q3_k_m.gguf",
    #     "context_size": 32768,
    #     "chat_format": "chatml",
    #     "description": "Stronger 7B, slower on CPU",
    #     "size": "4.5GB",
    # },
}

# Model id used when a request does not name one.
DEFAULT_MODEL = "qwen-3b"

# In-process cache of constructed models, keyed by model id.
loaded_models: Dict[str, Llama] = {}
current_model_id = DEFAULT_MODEL # ============================================================================ # REQUEST / RESPONSE MODELS # ============================================================================ class ChatMessage(BaseModel): role: str # "system" | "user" | "assistant" content: str class ChatRequest(BaseModel): messages: list[ChatMessage] model: str = DEFAULT_MODEL max_tokens: int = 512 temperature: float = 0.7 top_p: float = 0.9 repeat_penalty: float = 1.1 stream: bool = False # ============================================================================ # MODEL LOADING # ============================================================================ def download_model(model_id: str) -> Path: config = MODELS_CONFIG[model_id] model_path = MODEL_CACHE_DIR / config["file"] if model_path.exists(): logger.info(f"Cache hit: {model_path}") return model_path logger.info(f"Downloading {config['name']} from {config['repo']} ...") from huggingface_hub import hf_hub_download path = hf_hub_download( repo_id=config["repo"], filename=config["file"], local_dir=str(MODEL_CACHE_DIR), local_dir_use_symlinks=False, ) logger.info(f"Download complete → {path}") return Path(path) def load_model(model_id: str) -> Llama: global current_model_id if model_id in loaded_models: current_model_id = model_id return loaded_models[model_id] if model_id not in MODELS_CONFIG: raise ValueError(f"Unknown model: {model_id}") config = MODELS_CONFIG[model_id] model_path = download_model(model_id) logger.info(f"Loading {model_id} ...") llm = Llama( model_path=str(model_path), n_gpu_layers=0, # CPU only on free tier n_ctx=config["context_size"], n_threads=2, # Match free-tier vCPU count exactly n_batch=512, chat_format=config["chat_format"], verbose=False, ) loaded_models[model_id] = llm current_model_id = model_id logger.info(f"{model_id} ready") return llm def get_model(model_id: Optional[str] = None) -> Llama: mid = model_id or current_model_id if mid not in loaded_models: 
load_model(mid) return loaded_models[mid] @app.on_event("startup") async def startup_event(): load_model(DEFAULT_MODEL) # ============================================================================ # STREAMING HELPER # ============================================================================ async def _stream_completion(llm: Llama, kwargs: dict): """Yield SSE chunks in OpenAI streaming format.""" try: for chunk in llm.create_chat_completion(**kwargs, stream=True): delta = chunk["choices"][0].get("delta", {}) if delta.get("content"): yield f"data: {json.dumps(chunk)}\n\n" yield "data: [DONE]\n\n" except Exception as e: logger.error(f"Stream error: {e}") error_payload = {"error": {"message": str(e), "type": "server_error"}} yield f"data: {json.dumps(error_payload)}\n\n" # ============================================================================ # API ROUTES # ============================================================================ @app.get("/", response_class=HTMLResponse) async def root(): """Minimal status page — useful when you open the Space URL in a browser.""" model_rows = "".join( f"
OpenAI-compatible endpoint. Point SillyTavern here.
Chat CompletionCustom (OpenAI-compatible){{}YOUR_SPACE_URL{{}}{DEFAULT_MODEL}anything (not checked)GET /healthGET /v1/modelsPOST /v1/chat/completions| ID | Name | Size | Status |
|---|