# FastAPI service that wraps a local llama-server process and supports
# switching between a predefined set of GGUF models at runtime.
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import subprocess | |
| import signal | |
| import os | |
| import requests | |
| import time | |
| from typing import Optional | |
app = FastAPI()

# Predefined list of available models.
# Keys are the short aliases exposed through the API; values are Hugging Face
# specs in the form "<repo_id>:<gguf_filename>", passed to llama-server's -hf flag.
AVAILABLE_MODELS = {
    # === Financial & Summarization Models (Recommended) ===
    "qwen-2.5-7b": "bartowski/Qwen2.5-7B-Instruct-GGUF:Qwen2.5-7B-Instruct-Q4_K_M.gguf",  # Best for financial + multilingual
    "kimi-k2-9b": "bartowski/k2-chat-GGUF:k2-chat-Q4_K_M.gguf",  # Kimi K2 - long context, good reasoning
    "yi-1.5-9b": "bartowski/Yi-1.5-9B-Chat-GGUF:Yi-1.5-9B-Chat-Q4_K_M.gguf",  # Excellent for finance
    "llama-3.1-8b": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",  # Great reasoning
    "mistral-7b": "TheBloke/Mistral-7B-Instruct-v0.3-GGUF:mistral-7b-instruct-v0.3.Q4_K_M.gguf",  # Reliable summarization
    # === Coding Models ===
    "deepseek-coder": "TheBloke/deepseek-coder-6.7B-instruct-GGUF:deepseek-coder-6.7b-instruct.Q4_K_M.gguf",
    # === General Purpose ===
    "deepseek-chat": "TheBloke/deepseek-llm-7B-chat-GGUF:deepseek-llm-7b-chat.Q4_K_M.gguf",
    "llama-3.2-3b": "bartowski/Llama-3.2-3B-Instruct-GGUF:Llama-3.2-3B-Instruct-Q4_K_M.gguf",  # Fast & lightweight
}

# Global state
current_model = "deepseek-chat"  # Default model (alias; must be a key of AVAILABLE_MODELS)
llama_process: Optional[subprocess.Popen] = None  # Handle to the running llama-server, if any
LLAMA_SERVER_PORT = 8080
LLAMA_SERVER_URL = f"http://localhost:{LLAMA_SERVER_PORT}"
class ModelSwitchRequest(BaseModel):
    """Request body for the model-switch endpoint."""

    # Alias of the model to switch to; must be a key of AVAILABLE_MODELS.
    model_name: str
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible chat completions endpoint."""

    # Conversation as message dicts; forwarded verbatim to llama-server's
    # /v1/chat/completions, so presumably OpenAI-style {"role", "content"}
    # entries — confirm against callers.
    messages: list[dict]
    max_tokens: int = 256
    temperature: float = 0.7
def start_llama_server(model_id: str) -> subprocess.Popen:
    """Start llama-server with specified model (optimized for speed).

    Launches llama-server as a subprocess and blocks until it answers HTTP
    requests on LLAMA_SERVER_URL, or fails.

    Args:
        model_id: Hugging Face spec ("<repo_id>:<gguf_filename>") passed to -hf.

    Returns:
        The running llama-server subprocess handle.

    Raises:
        RuntimeError: if the process exits early, or does not become
            reachable within the retry window.
    """
    cmd = [
        "llama-server",
        "-hf", model_id,
        "--host", "0.0.0.0",
        "--port", str(LLAMA_SERVER_PORT),
        "-c", "2048",  # Context size
        "-t", "4",  # CPU threads (adjust based on cores)
        "-ngl", "0",  # GPU layers (0 for CPU-only)
        "--cont-batching",  # Enable continuous batching for speed
        "-b", "512",  # Batch size
    ]
    print(f"Starting llama-server with model: {model_id}")
    print("This may take 2-3 minutes to download and load the model...")
    # On POSIX, start a new session so stop_llama_server can signal the
    # whole process group at once.
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        preexec_fn=os.setsid if os.name != 'nt' else None,
        text=True,
        bufsize=1,
    )

    # Wait for server to be ready (increased timeout for model download).
    max_retries = 300  # ~5 minutes of 1-second polls
    for i in range(max_retries):
        # Fail fast if the process died (bad model id, missing binary, OOM...).
        if process.poll() is not None:
            stdout, _ = process.communicate()
            print(f"llama-server exited with code {process.returncode}")
            print(f"Output: {stdout}")
            raise RuntimeError("llama-server process died")
        try:
            # Try root endpoint instead of /health
            response = requests.get(f"{LLAMA_SERVER_URL}/", timeout=2)
            if response.status_code in (200, 404):  # 404 is ok, means server is up
                print(f"llama-server ready after {i+1} seconds")
                return process
        except requests.exceptions.RequestException:
            # Not reachable yet (connection refused, timeout, ...); keep polling.
            # Narrowed from a broad bare `except Exception` so real bugs
            # (e.g. NameError) are no longer silently swallowed.
            pass
        time.sleep(1)

    raise RuntimeError("llama-server failed to start within 5 minutes")
def stop_llama_server():
    """Stop the running llama-server, if any, and clear the global handle.

    Tries a graceful SIGTERM first (to the whole process group on POSIX,
    matching the os.setsid used at launch), escalating to SIGKILL if the
    process hangs or the graceful path fails.
    """
    global llama_process
    if llama_process:
        print("Stopping llama-server...")
        try:
            if os.name != 'nt':
                os.killpg(os.getpgid(llama_process.pid), signal.SIGTERM)
            else:
                llama_process.terminate()
            llama_process.wait(timeout=10)
        except (subprocess.TimeoutExpired, ProcessLookupError, OSError):
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed. Force-kill as a fallback, tolerating
            # the process having already exited (killpg would raise again).
            try:
                if os.name != 'nt':
                    os.killpg(os.getpgid(llama_process.pid), signal.SIGKILL)
                else:
                    llama_process.kill()
            except (ProcessLookupError, OSError):
                pass
        llama_process = None
        time.sleep(2)  # Give it time to fully shut down
@app.on_event("startup")  # NOTE(review): decorator absent in the pasted source; without it FastAPI never runs this hook — restored, confirm against original
async def startup_event():
    """Start llama-server with the default model when the API boots."""
    global llama_process
    model_id = AVAILABLE_MODELS[current_model]
    llama_process = start_llama_server(model_id)
@app.on_event("shutdown")  # NOTE(review): decorator absent in the pasted source; restored so cleanup actually runs — confirm against original
async def shutdown_event():
    """Clean shutdown: stop the llama-server subprocess."""
    stop_llama_server()
@app.get("/")  # NOTE(review): route decorator absent in the pasted source; restored (root status handler)
async def root():
    """Service status plus the current and available model aliases."""
    return {
        "status": "DeepSeek API with dynamic model switching",
        "current_model": current_model,
        "available_models": list(AVAILABLE_MODELS.keys()),
    }
@app.get("/models")  # NOTE(review): decorator absent in the pasted source; path inferred from the handler name — confirm against clients
async def list_models():
    """List all available models."""
    return {
        "current_model": current_model,
        "available_models": list(AVAILABLE_MODELS.keys()),
    }
@app.post("/switch-model")  # NOTE(review): decorator absent in the pasted source; path inferred — confirm against clients
async def switch_model(request: ModelSwitchRequest):
    """Switch to a different model.

    Stops the current llama-server and starts a new one with the requested
    model. Blocks until the new server is ready, which can take minutes on
    first download.

    Raises:
        HTTPException: 400 when the requested model alias is unknown.
    """
    global current_model, llama_process

    if request.model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{request.model_name}' not found. Available: {list(AVAILABLE_MODELS.keys())}"
        )
    # No-op if the requested model is already loaded.
    if request.model_name == current_model:
        return {"message": f"Already using model: {current_model}"}

    # Stop current server
    stop_llama_server()
    # Start with new model
    model_id = AVAILABLE_MODELS[request.model_name]
    llama_process = start_llama_server(model_id)
    current_model = request.model_name

    return {
        "message": f"Switched to model: {current_model}",
        "model": current_model,
    }
@app.post("/v1/chat/completions")  # NOTE(review): decorator absent in the pasted source; path restored from the OpenAI-compatible contract and the forwarded URL
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Forwards the request to the local llama-server and returns its JSON
    response unchanged.

    Raises:
        HTTPException: 500 when llama-server is unreachable, times out,
            or answers with an HTTP error status.
    """
    try:
        # Forward to llama-server
        response = requests.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json={
                "messages": request.messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
            timeout=300,  # generous: CPU-only generation can be slow
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        # Chain the cause so server logs show the underlying requests error.
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}") from e