Spaces:
Paused
Paused
| import os | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional, Dict | |
| from datetime import datetime | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.responses import StreamingResponse, HTMLResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from llama_cpp import Llama | |
| # ============================================================================ | |
| # SETUP & CONFIG | |
| # ============================================================================ | |
# Configure logging once at import time: INFO and above, each record prefixed
# with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The ASGI application object; route handlers and middleware attach to it.
app = FastAPI(version="1.0.0", title="LLM Chat API")

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended before relying on cookie-based auth.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Persistent storage: HF mounts the bucket at /data.
# When that mount is absent (local dev) use ~/data instead.
_DATA_DIR = Path("/data")
if not _DATA_DIR.exists():
    _DATA_DIR = Path.home() / "data"

# Directory holding downloaded GGUF files; created eagerly at import time.
MODEL_CACHE_DIR = _DATA_DIR.joinpath("models")
MODEL_CACHE_DIR.mkdir(exist_ok=True, parents=True)
logger.info(f"Model cache dir: {MODEL_CACHE_DIR}")
| # ============================================================================ | |
| # MODEL REGISTRY | |
| # ============================================================================ | |
# Registry of servable models, keyed by the id clients send in `model`.
# Each entry gives the Hugging Face repo/file to fetch, the context window and
# chat template handed to llama.cpp, and display metadata for the status page.
MODELS_CONFIG = {
    "qwen-3b": dict(
        name="Qwen 2.5 3B Instruct",
        repo="Qwen/Qwen2.5-3B-Instruct-GGUF",
        file="qwen2.5-3b-instruct-q4_k_m.gguf",
        context_size=32768,
        chat_format="chatml",
        description="Fast 3B model with 32k context",
        size="2.5GB",
    ),
    # Uncomment to add more (watch RAM — free tier has 16GB total):
    # "qwen-7b": dict(
    #     name="Qwen 2.5 7B Instruct",
    #     repo="Qwen/Qwen2.5-7B-Instruct-GGUF",
    #     file="qwen2.5-7b-instruct-q3_k_m.gguf",
    #     context_size=32768,
    #     chat_format="chatml",
    #     description="Stronger 7B, slower on CPU",
    #     size="4.5GB",
    # ),
}
# Registry id used when a request does not name a model, and preloaded at startup.
DEFAULT_MODEL = "qwen-3b"

# Instantiated models keyed by registry id; populated by load_model().
loaded_models: Dict[str, Llama] = dict()

# Id of the model most recently selected through load_model().
current_model_id = DEFAULT_MODEL
| # ============================================================================ | |
| # REQUEST / RESPONSE MODELS | |
| # ============================================================================ | |
# One turn of the conversation, shaped like an OpenAI chat message.
class ChatMessage(BaseModel):
    # Speaker of the turn: "system" | "user" | "assistant".
    # Declared as a plain str, so other values are not rejected here.
    role: str
    # Text of the turn.
    content: str
# Body accepted by the chat-completions handler.
class ChatRequest(BaseModel):
    # Conversation so far, oldest turn first.
    messages: list[ChatMessage]
    # Registry id of the model to run; falls back to the default model.
    model: str = DEFAULT_MODEL
    # Sampling controls, forwarded unchanged to llama.cpp.
    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    repeat_penalty: float = 1.1
    # When true the reply is sent as server-sent events instead of one JSON body.
    stream: bool = False
| # ============================================================================ | |
| # MODEL LOADING | |
| # ============================================================================ | |
def download_model(model_id: str) -> Path:
    """Return the local path of the GGUF file for *model_id*, fetching it if absent.

    The file is looked up in ``MODEL_CACHE_DIR`` first; only when it is missing
    is it downloaded from the repo named in ``MODELS_CONFIG``.

    Raises:
        KeyError: if *model_id* is not a key of ``MODELS_CONFIG``.
    """
    entry = MODELS_CONFIG[model_id]
    cached_file = MODEL_CACHE_DIR / entry["file"]

    if not cached_file.exists():
        logger.info(f"Downloading {entry['name']} from {entry['repo']} ...")
        # Imported lazily so the dependency is only touched on a cache miss.
        from huggingface_hub import hf_hub_download

        fetched = hf_hub_download(
            repo_id=entry["repo"],
            filename=entry["file"],
            local_dir=str(MODEL_CACHE_DIR),
            local_dir_use_symlinks=False,
        )
        logger.info(f"Download complete → {fetched}")
        return Path(fetched)

    logger.info(f"Cache hit: {cached_file}")
    return cached_file
def load_model(model_id: str) -> Llama:
    """Return the ``Llama`` instance for *model_id*, creating it on first use.

    A model that is already in ``loaded_models`` is returned as-is. Otherwise
    its weights are obtained via ``download_model`` and a new instance is built
    and cached. In both cases ``current_model_id`` is updated to *model_id*.

    Raises:
        ValueError: if *model_id* is neither loaded nor present in ``MODELS_CONFIG``.
    """
    global current_model_id

    already_loaded = model_id in loaded_models
    if already_loaded:
        current_model_id = model_id
        return loaded_models[model_id]

    if model_id not in MODELS_CONFIG:
        raise ValueError(f"Unknown model: {model_id}")

    spec = MODELS_CONFIG[model_id]
    weights_path = download_model(model_id)

    logger.info(f"Loading {model_id} ...")
    llama_options = dict(
        model_path=str(weights_path),
        n_gpu_layers=0,  # CPU only on free tier
        n_ctx=spec["context_size"],
        n_threads=2,  # Match free-tier vCPU count exactly
        n_batch=512,
        chat_format=spec["chat_format"],
        verbose=False,
    )
    instance = Llama(**llama_options)

    loaded_models[model_id] = instance
    current_model_id = model_id
    logger.info(f"{model_id} ready")
    return instance
def get_model(model_id: Optional[str] = None) -> Llama:
    """Return a loaded model, loading it first when necessary.

    Args:
        model_id: Registry id to fetch; a falsy value selects ``current_model_id``.
    """
    target = model_id if model_id else current_model_id
    needs_loading = target not in loaded_models
    if needs_loading:
        load_model(target)
    return loaded_models[target]
# Registered as a startup hook so the default model is loaded before the first
# request arrives; without the registration this coroutine is never invoked.
@app.on_event("startup")
async def startup_event():
    """Preload ``DEFAULT_MODEL`` when the application starts."""
    load_model(DEFAULT_MODEL)
| # ============================================================================ | |
| # STREAMING HELPER | |
| # ============================================================================ | |
| async def _stream_completion(llm: Llama, kwargs: dict): | |
| """Yield SSE chunks in OpenAI streaming format.""" | |
| try: | |
| for chunk in llm.create_chat_completion(**kwargs, stream=True): | |
| delta = chunk["choices"][0].get("delta", {}) | |
| if delta.get("content"): | |
| yield f"data: {json.dumps(chunk)}\n\n" | |
| yield "data: [DONE]\n\n" | |
| except Exception as e: | |
| logger.error(f"Stream error: {e}") | |
| error_payload = {"error": {"message": str(e), "type": "server_error"}} | |
| yield f"data: {json.dumps(error_payload)}\n\n" | |
| # ============================================================================ | |
| # API ROUTES | |
| # ============================================================================ | |
# response_class=HTMLResponse makes FastAPI send the returned string as a
# text/html page instead of JSON-encoding it.
@app.get("/", response_class=HTMLResponse)
async def root():
    """Minimal status page — useful when you open the Space URL in a browser.

    Returns:
        An HTML document with client setup hints, the list of endpoints, and a
        table of every model in ``MODELS_CONFIG`` with its load status.
    """
    model_rows = "".join(
        f"<tr><td>{mid}</td><td>{cfg['name']}</td><td>{cfg['size']}</td>"
        f"<td>{'✅ loaded' if mid in loaded_models else '—'}</td></tr>"
        for mid, cfg in MODELS_CONFIG.items()
    )
    # Literal braces (CSS rules, the URL placeholder) are doubled because this
    # is an f-string; "{{YOUR_SPACE_URL}}" renders as "{YOUR_SPACE_URL}".
    return f"""<!DOCTYPE html>
<html><head><title>LLM API</title>
<style>
body {{ font-family: sans-serif; max-width: 700px; margin: 60px auto; color: #e2e8f0; background: #0f172a; }}
h1 {{ color: #06b6d4; }} code {{ background: #1e293b; padding: 2px 6px; border-radius: 4px; }}
table {{ border-collapse: collapse; width: 100%; margin-top: 16px; }}
th, td {{ text-align: left; padding: 8px 12px; border-bottom: 1px solid #334155; }}
th {{ color: #94a3b8; font-size: 12px; text-transform: uppercase; }}
</style></head><body>
<h1>🤖 LLM Chat API</h1>
<p>OpenAI-compatible endpoint. Point SillyTavern here.</p>
<h3>SillyTavern setup</h3>
<ul>
<li>API: <code>Chat Completion</code></li>
<li>Source: <code>Custom (OpenAI-compatible)</code></li>
<li>Server URL: <code>{{YOUR_SPACE_URL}}</code></li>
<li>Model: <code>{DEFAULT_MODEL}</code></li>
<li>API Key: <code>anything</code> (not checked)</li>
</ul>
<h3>Endpoints</h3>
<ul>
<li><code>GET /health</code></li>
<li><code>GET /v1/models</code></li>
<li><code>POST /v1/chat/completions</code></li>
</ul>
<h3>Models</h3>
<table>
<tr><th>ID</th><th>Name</th><th>Size</th><th>Status</th></tr>
{model_rows}
</table>
</body></html>"""
# Registered at the path the status page advertises as "GET /health".
@app.get("/health")
async def health():
    """Report service status.

    Returns:
        A dict with a fixed ``"healthy"`` status, the currently selected model
        id, the ids of all loaded models, and the model cache directory.
    """
    return {
        "status": "healthy",
        "current_model": current_model_id,
        "models_loaded": list(loaded_models.keys()),
        "cache_dir": str(MODEL_CACHE_DIR),
    }
# Registered at the path the status page advertises as "GET /v1/models".
@app.get("/v1/models")
async def list_models():
    """List every model in ``MODELS_CONFIG`` as an OpenAI-style model list.

    Returns:
        ``{"object": "list", "data": [...]}`` with one entry per configured
        model. ``created`` is the time of the request, and ``loaded`` tells
        whether the model is currently in ``loaded_models``.
    """
    return {
        "object": "list",
        "data": [
            {
                "id": mid,
                "object": "model",
                "created": int(datetime.now().timestamp()),
                "owned_by": "local",
                "context_length": cfg["context_size"],
                "description": cfg["description"],
                "loaded": mid in loaded_models,
            }
            for mid, cfg in MODELS_CONFIG.items()
        ],
    }
# Registered at the path the status page advertises as
# "POST /v1/chat/completions".
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    """Run a chat completion on the requested model.

    Args:
        request: Parsed request body (messages, model id, sampling controls,
            and the ``stream`` flag).

    Returns:
        A ``StreamingResponse`` of server-sent events when ``request.stream``
        is true, otherwise the completion dict produced by llama-cpp.

    Raises:
        HTTPException: 400 if the model id is not in ``MODELS_CONFIG``;
            503 if the model cannot be obtained.
    """
    if request.model not in MODELS_CONFIG:
        raise HTTPException(status_code=400, detail=f"Unknown model: {request.model}")

    try:
        llm = get_model(request.model)
    except Exception as e:
        # Keep the traceback in the server log and chain the cause so the
        # original failure is not lost behind the HTTP error.
        logger.exception(f"Model unavailable: {e}")
        raise HTTPException(status_code=503, detail=f"Model unavailable: {e}") from e

    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    # Only stop on real chat template boundary tokens — never on \n\n
    stop_tokens = ["<|im_end|>", "<|im_start|>"]
    kwargs = dict(
        messages=messages,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
        repeat_penalty=request.repeat_penalty,
        stop=stop_tokens,
    )

    if request.stream:
        return StreamingResponse(
            _stream_completion(llm, kwargs),
            media_type="text/event-stream",
            # Ask reverse proxies that honour this header not to buffer the stream.
            headers={"X-Accel-Buffering": "no"},
        )

    # NOTE(review): this call is synchronous inside a coroutine, so nothing
    # else is served by this event loop until generation completes.
    output = llm.create_chat_completion(**kwargs)
    return output  # already OpenAI-compatible from llama-cpp
| # ============================================================================ | |
| # ENTRYPOINT | |
| # ============================================================================ | |
if __name__ == "__main__":
    # Direct execution: serve the app on every interface at port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)