Spaces:
Sleeping
Sleeping
File size: 6,501 Bytes
dde400a 9d0ed97 dde400a 9d0ed97 dde400a 9d0ed97 dde400a 9d0ed97 dde400a e80973f dde400a e80973f 9d0ed97 e80973f dde400a e80973f dde400a e80973f dde400a e80973f dde400a e80973f dde400a 9d0ed97 dde400a 9d0ed97 dde400a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import subprocess
import signal
import os
import requests
import time
from typing import Optional
app = FastAPI()
# Predefined list of available models.
# Maps a short user-facing alias to a "repo:filename" spec that is passed
# verbatim to llama-server's `-hf` flag (HuggingFace repo + GGUF quant file).
AVAILABLE_MODELS = {
    # === Financial & Summarization Models (Recommended) ===
    "qwen-2.5-7b": "bartowski/Qwen2.5-7B-Instruct-GGUF:Qwen2.5-7B-Instruct-Q4_K_M.gguf",  # Best for financial + multilingual
    "kimi-k2-9b": "bartowski/k2-chat-GGUF:k2-chat-Q4_K_M.gguf",  # Kimi K2 - long context, good reasoning
    "yi-1.5-9b": "bartowski/Yi-1.5-9B-Chat-GGUF:Yi-1.5-9B-Chat-Q4_K_M.gguf",  # Excellent for finance
    "llama-3.1-8b": "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",  # Great reasoning
    "mistral-7b": "TheBloke/Mistral-7B-Instruct-v0.3-GGUF:mistral-7b-instruct-v0.3.Q4_K_M.gguf",  # Reliable summarization
    # === Coding Models ===
    "deepseek-coder": "TheBloke/deepseek-coder-6.7B-instruct-GGUF:deepseek-coder-6.7b-instruct.Q4_K_M.gguf",
    # === General Purpose ===
    "deepseek-chat": "TheBloke/deepseek-llm-7B-chat-GGUF:deepseek-llm-7b-chat.Q4_K_M.gguf",
    "llama-3.2-3b": "bartowski/Llama-3.2-3B-Instruct-GGUF:Llama-3.2-3B-Instruct-Q4_K_M.gguf",  # Fast & lightweight
}
# Global state
# NOTE: mutated by switch_model / startup / shutdown handlers; there is no
# lock around these, so concurrent /switch-model calls could race.
current_model = "deepseek-chat"  # Default model
llama_process: Optional[subprocess.Popen] = None  # Handle to the running llama-server, or None.
LLAMA_SERVER_PORT = 8080
LLAMA_SERVER_URL = f"http://localhost:{LLAMA_SERVER_PORT}"
class ModelSwitchRequest(BaseModel):
    """Request body for POST /switch-model."""
    # Key into AVAILABLE_MODELS (e.g. "deepseek-chat").
    model_name: str
class ChatCompletionRequest(BaseModel):
    """Request body for POST /v1/chat/completions (OpenAI-style)."""
    # Chat history; presumably dicts with "role"/"content" keys — forwarded
    # verbatim to llama-server, so the exact schema is its contract.
    messages: list[dict]
    max_tokens: int = 256  # Cap on generated tokens.
    temperature: float = 0.7  # Sampling temperature.
def start_llama_server(model_id: str) -> subprocess.Popen:
    """Start llama-server with the specified model and wait until it answers.

    Args:
        model_id: "repo:filename" HuggingFace spec passed to llama-server's
            `-hf` flag; the server downloads the GGUF file if needed.

    Returns:
        The running `subprocess.Popen` handle once the HTTP port responds.

    Raises:
        RuntimeError: if the process exits early or never becomes reachable
            within the retry budget.
    """
    cmd = [
        "llama-server",
        "-hf", model_id,
        "--host", "0.0.0.0",
        "--port", str(LLAMA_SERVER_PORT),
        "-c", "2048",          # Context size
        "-t", "4",             # CPU threads (adjust based on cores)
        "-ngl", "0",           # GPU layers (0 for CPU-only)
        "--cont-batching",     # Enable continuous batching for speed
        "-b", "512",           # Batch size
    ]
    print(f"Starting llama-server with model: {model_id}")
    print("This may take 2-3 minutes to download and load the model...")
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,      # Merge stderr so one stream captures everything.
        preexec_fn=os.setsid if os.name != 'nt' else None,  # New process group (POSIX) so we can kill children too.
        text=True,
        bufsize=1,
    )
    # Wait for the server to come up. Each attempt sleeps 1s and the probe
    # itself can block up to 2s, so 300 attempts is *at least* 5 minutes.
    max_retries = 300
    for attempt in range(max_retries):
        # Fail fast if the process died (bad model id, OOM, missing binary...).
        if process.poll() is not None:
            stdout, _ = process.communicate()
            print(f"llama-server exited with code {process.returncode}")
            print(f"Output: {stdout}")
            raise RuntimeError("llama-server process died")
        try:
            # Probe the root endpoint instead of /health: any HTTP answer
            # (including 404) proves the server is accepting connections.
            response = requests.get(f"{LLAMA_SERVER_URL}/", timeout=2)
            if response.status_code in (200, 404):
                print(f"llama-server ready after {attempt + 1} attempts")
                return process
        except requests.exceptions.RequestException:
            # Not reachable yet (connection refused, timeout, ...) — keep waiting.
            pass
        time.sleep(1)
    raise RuntimeError("llama-server failed to start within 5 minutes")
def stop_llama_server():
    """Stop the running llama-server, if any.

    Sends SIGTERM to the whole process group on POSIX (llama-server may have
    spawned children) or terminates the process on Windows, waits up to 10s,
    then escalates to SIGKILL. Clears the module-level `llama_process` handle.
    """
    global llama_process
    if llama_process:
        print("Stopping llama-server...")
        try:
            if os.name != 'nt':
                os.killpg(os.getpgid(llama_process.pid), signal.SIGTERM)
            else:
                llama_process.terminate()
            llama_process.wait(timeout=10)
        except (subprocess.TimeoutExpired, ProcessLookupError, OSError):
            # Graceful shutdown failed (or the group is already gone): force kill.
            try:
                if os.name != 'nt':
                    os.killpg(os.getpgid(llama_process.pid), signal.SIGKILL)
                else:
                    llama_process.kill()
            except (ProcessLookupError, OSError):
                pass  # Process already dead — nothing left to do.
        llama_process = None
        time.sleep(2)  # Give the OS time to release the port before a restart.
@app.on_event("startup")
async def startup_event():
    """Start with default model.

    Blocks startup until llama-server is reachable (can take minutes on a
    cold start while the model downloads).
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI — consider
    # migrating to a lifespan context manager; verify the installed version.
    global llama_process
    model_id = AVAILABLE_MODELS[current_model]
    llama_process = start_llama_server(model_id)
@app.on_event("shutdown")
async def shutdown_event():
    """Clean shutdown: terminate the backing llama-server process."""
    stop_llama_server()
@app.get("/")
async def root():
    """Service banner: reports the active model and all selectable aliases."""
    model_names = list(AVAILABLE_MODELS.keys())
    payload = {
        "status": "DeepSeek API with dynamic model switching",
        "current_model": current_model,
        "available_models": model_names,
    }
    return payload
@app.get("/models")
async def list_models():
    """Enumerate every model alias this service can switch to."""
    response = {"current_model": current_model}
    response["available_models"] = list(AVAILABLE_MODELS.keys())
    return response
@app.post("/switch-model")
async def switch_model(request: ModelSwitchRequest):
    """Switch the backend llama-server to a different model.

    Raises:
        HTTPException 400: unknown model name.
        HTTPException 500: the replacement server failed to start — the old
            server has already been stopped at that point, so no model is
            loaded until a successful /switch-model or restart.
    """
    global current_model, llama_process
    if request.model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{request.model_name}' not found. Available: {list(AVAILABLE_MODELS.keys())}"
        )
    if request.model_name == current_model:
        # No-op: avoid an expensive restart when nothing would change.
        return {"message": f"Already using model: {current_model}"}
    # Stop current server (port must be free before the new one binds).
    stop_llama_server()
    # Start with new model.
    model_id = AVAILABLE_MODELS[request.model_name]
    try:
        llama_process = start_llama_server(model_id)
    except RuntimeError as e:
        # Surface a clear error instead of an opaque unhandled-exception 500;
        # current_model is deliberately left unchanged only on success.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to start model '{request.model_name}': {e}"
        ) from e
    current_model = request.model_name
    return {
        "message": f"Switched to model: {current_model}",
        "model": current_model
    }
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Forwards the request to the local llama-server and relays its JSON
    response unchanged. Generous 300s timeout: CPU inference is slow.

    Raises:
        HTTPException 500: llama-server is unreachable, timed out, or
            returned an HTTP error status.
    """
    try:
        # Forward to llama-server; only the fields we model are passed on.
        response = requests.post(
            f"{LLAMA_SERVER_URL}/v1/chat/completions",
            json={
                "messages": request.messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
            },
            timeout=300,
        )
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        # Chain the cause so server logs show the underlying requests error.
        raise HTTPException(status_code=500, detail=f"llama-server error: {str(e)}") from e