import os
import warnings
from typing import List, Tuple, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles

# Suppress library warnings
warnings.filterwarnings("ignore")
# Ensure the models directory exists
MODEL_DIR = "./models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Hugging Face Hub repos and GGUF filenames to download
MODELS_INFO = [
    {
        "repo_id": "bartowski/Dolphin3.0-Llama3.2-1B-GGUF",
        "filename": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf",
    },
    {
        "repo_id": "bartowski/Dolphin3.0-Qwen2.5-0.5B-GGUF",
        "filename": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf",
    },
    {
        "repo_id": "bartowski/Qwen2.5-Coder-14B-Instruct-GGUF",
        "filename": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf",
    },
]
# Download any model that is not already present locally
for model_info in MODELS_INFO:
    model_path = os.path.join(MODEL_DIR, model_info["filename"])
    if not os.path.exists(model_path):
        print(f"Downloading {model_info['filename']} from {model_info['repo_id']}...")
        try:
            hf_hub_download(
                repo_id=model_info["repo_id"],
                filename=model_info["filename"],
                local_dir=MODEL_DIR,
            )
            print(f"Downloaded {model_info['filename']}")
        except Exception as e:
            print(f"Error downloading {model_info['filename']}: {e}")
# Short model keys exposed by the API, mapped to local GGUF filenames
AVAILABLE_MODELS = {
    "qwen": "Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf",
    "llama": "Dolphin3.0-Llama3.2-1B-Q4_K_M.gguf",
    "coder": "Qwen2.5-Coder-14B-Instruct-Q6_K.gguf",
}

# Global LLM instance, reused across requests until a different model key is requested
llm = None
llm_model = None
def load_model(model_key: str):
    """Load (or reuse) the Llama instance for the given model key."""
    global llm, llm_model
    model_name = AVAILABLE_MODELS.get(model_key)
    if not model_name:
        raise ValueError(f"Invalid model key: {model_key}")
    model_path = os.path.join(MODEL_DIR, model_name)
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")
    # Only (re)load when nothing is loaded yet or a different model is requested
    if llm is None or llm_model != model_name:
        llm = Llama(
            model_path=model_path,
            flash_attn=False,
            n_gpu_layers=0,  # CPU-only inference
            n_batch=8,
            n_ctx=2048,
            n_threads=8,
            n_threads_batch=8,
        )
        llm_model = model_name
    return llm
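
# A quick sketch of the caching behavior above (assumed usage, not part of the API):
#   load_model("qwen")   # first call: loads Dolphin3.0-Qwen2.5-0.5B-Q6_K.gguf from disk
#   load_model("qwen")   # same key: the existing Llama instance is reused
#   load_model("llama")  # new key: the 1B Llama 3.2 model replaces it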
class ChatRequest(BaseModel):
    message: str  # Required
    history: Optional[List[Tuple[str, str]]] = []  # Prior (user, assistant) turns
    model: Optional[str] = "qwen"  # Default model key
    system_prompt: Optional[str] = "You are Dolphin, a helpful AI assistant."
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 40
    repeat_penalty: Optional[float] = 1.1


class ChatResponse(BaseModel):
    response: str


class ModelInfoResponse(BaseModel):
    models: List[str]
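
# A minimal sketch of the JSON body ChatRequest accepts; only "message" is
# required, everything else falls back to the defaults above:
#   {
#     "message": "Explain list comprehensions.",
#     "history": [["Hi", "Hello! How can I help?"]],
#     "model": "llama",
#     "temperature": 0.5
#   }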
app = FastAPI(
    title="Dolphin 3.0 LLM API",
    description="REST API for Dolphin 3.0 models using the llama.cpp backend.",
    version="1.0",
    docs_url="/docs",  # Keep Swagger UI
    redoc_url=None,    # Disable ReDoc
)
# NOTE: the original listing omitted the route decorators; the paths below are assumed.
@app.get("/models", response_model=ModelInfoResponse)
def get_available_models():
    """Returns the list of supported model keys."""
    return {"models": list(AVAILABLE_MODELS.keys())}
@app.post("/chat", response_model=ChatResponse)
def chat(request: ChatRequest):
    try:
        # Load (or reuse) the requested model and wrap it in a provider
        model = load_model(request.model)
        provider = LlamaCppPythonProvider(model)
        agent = LlamaCppAgent(
            provider,
            system_prompt=request.system_prompt,
            predefined_messages_formatter_type=MessagesFormatterType.CHATML,
        )

        # Apply sampling settings from the request
        settings = provider.get_provider_default_settings()
        settings.temperature = request.temperature
        settings.top_k = request.top_k
        settings.top_p = request.top_p
        settings.max_tokens = request.max_tokens
        settings.repeat_penalty = request.repeat_penalty

        # Rebuild chat history from prior (user, assistant) turns
        messages = BasicChatHistory()
        for user_msg, assistant_msg in (request.history or []):
            messages.add_message({"role": Roles.user, "content": user_msg})
            messages.add_message({"role": Roles.assistant, "content": assistant_msg})

        # Generate the response
        response = agent.get_chat_response(
            request.message,
            llm_sampling_settings=settings,
            chat_history=messages,
            print_output=False,
        )
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
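
# --- Usage sketch ---
# Assumes the server is running locally on port 8000 with the /models and /chat
# routes as defined above; `requests` is a client-side dependency, not part of
# this app.
#
#   import requests
#
#   base = "http://localhost:8000"
#   print(requests.get(f"{base}/models").json())
#
#   payload = {
#       "message": "Write a haiku about dolphins.",
#       "model": "qwen",
#       "max_tokens": 128,
#   }
#   print(requests.post(f"{base}/chat", json=payload).json()["response"])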