Spaces:
Paused
Paused
Commit
·
cfaa883
1
Parent(s):
19b1be5
Refactor v2
Browse files- .gitignore +1 -0
- Dockerfile +18 -41
- main/__init__.py +0 -0
- main/env_template +0 -55
- main/main.py +0 -61
- main/routes.py +0 -419
- requirements.txt +35 -45
.gitignore
CHANGED
|
@@ -42,3 +42,4 @@ wheels/
|
|
| 42 |
# Logs
|
| 43 |
*.log
|
| 44 |
logs/
|
|
|
|
|
|
| 42 |
# Logs
|
| 43 |
*.log
|
| 44 |
logs/
|
| 45 |
+
.cache/
|
Dockerfile
CHANGED
|
@@ -1,56 +1,33 @@
|
|
| 1 |
-
#
|
| 2 |
-
FROM
|
| 3 |
|
| 4 |
# Set working directory
|
| 5 |
-
WORKDIR /
|
| 6 |
|
| 7 |
-
# Install
|
| 8 |
-
RUN apt-get update && \
|
| 9 |
-
|
|
|
|
| 10 |
git \
|
| 11 |
-
wget \
|
| 12 |
-
&& apt-get clean \
|
| 13 |
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
|
| 15 |
-
# Create and set permissions for directories
|
| 16 |
-
RUN mkdir -p /app/.cache/huggingface && \
|
| 17 |
-
chmod 777 /app/.cache/huggingface && \
|
| 18 |
-
mkdir -p /app/.git && \
|
| 19 |
-
chmod 777 /app/.git
|
| 20 |
-
|
| 21 |
-
# Set environment variables
|
| 22 |
-
ENV TRANSFORMERS_CACHE=/app/.cache/huggingface/hub
|
| 23 |
-
ENV HF_HOME=/app/.cache/huggingface
|
| 24 |
-
ENV GIT_CONFIG_GLOBAL=/app/.git/config
|
| 25 |
-
|
| 26 |
# Copy requirements first to leverage Docker cache
|
| 27 |
COPY requirements.txt .
|
| 28 |
|
| 29 |
# Install Python dependencies
|
| 30 |
-
RUN
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
-
#
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
litgpt download mistralai/Mistral-7B-Instruct-v0.3 \
|
| 41 |
-
--access_token ${HF_TOKEN} \
|
| 42 |
-
--checkpoint_dir /app/main/checkpoints || { echo "Download failed with status $?"; exit 1; }
|
| 43 |
-
|
| 44 |
-
# Copy the rest of the application
|
| 45 |
-
COPY . .
|
| 46 |
-
|
| 47 |
-
# Set environment variables for the application
|
| 48 |
-
ENV LLM_ENGINE_HOST=0.0.0.0
|
| 49 |
-
ENV LLM_ENGINE_PORT=7860
|
| 50 |
-
ENV MODEL_PATH=/app/main/checkpoints/mistralai/Mistral-7B-Instruct-v0.3
|
| 51 |
|
| 52 |
-
# Expose port
|
| 53 |
-
EXPOSE
|
| 54 |
|
| 55 |
# Command to run the application
|
| 56 |
-
CMD ["
|
|
|
|
| 1 |
+
# Start from NVIDIA CUDA base image
|
| 2 |
+
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
| 3 |
|
| 4 |
# Set working directory
|
| 5 |
+
WORKDIR /code
|
| 6 |
|
| 7 |
+
# Install system dependencies
|
| 8 |
+
RUN apt-get update && apt-get install -y \
|
| 9 |
+
python3.12 \
|
| 10 |
+
python3-pip \
|
| 11 |
git \
|
|
|
|
|
|
|
| 12 |
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# Copy requirements first to leverage Docker cache
|
| 15 |
COPY requirements.txt .
|
| 16 |
|
| 17 |
# Install Python dependencies
|
| 18 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
| 19 |
|
| 20 |
+
# Copy the application code
|
| 21 |
+
COPY ./app /code/app
|
| 22 |
+
COPY ./utils /code/utils
|
| 23 |
|
| 24 |
+
# Set environment variables
|
| 25 |
+
ENV PYTHONPATH=/code
|
| 26 |
+
ENV TRANSFORMERS_CACHE=/code/app/.cache
|
| 27 |
+
ENV CUDA_VISIBLE_DEVICES=0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# Expose the port the app runs on
|
| 30 |
+
EXPOSE 8000
|
| 31 |
|
| 32 |
# Command to run the application
|
| 33 |
+
CMD ["python3", "-m", "app.main"]
|
main/__init__.py
DELETED
|
File without changes
|
main/env_template
DELETED
|
@@ -1,55 +0,0 @@
|
|
| 1 |
-
# Service URLs Configuration
|
| 2 |
-
LLM_ENGINE_URL=http://localhost:8001
|
| 3 |
-
RAG_ENGINE_URL=http://localhost:8002
|
| 4 |
-
|
| 5 |
-
# LLM Engine Server Configuration
|
| 6 |
-
LLM_ENGINE_HOST=0.0.0.0
|
| 7 |
-
LLM_ENGINE_PORT=8001
|
| 8 |
-
|
| 9 |
-
# RAG Engine Server Configuration (if running locally)
|
| 10 |
-
RAG_ENGINE_HOST=0.0.0.0
|
| 11 |
-
RAG_ENGINE_PORT=8002
|
| 12 |
-
|
| 13 |
-
# Base Paths Configuration
|
| 14 |
-
BAS_MODEL_PATH=/path/to/your/model
|
| 15 |
-
BAS_RESOURCES=/path/to/resources
|
| 16 |
-
|
| 17 |
-
# CUDA Memory Management
|
| 18 |
-
PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128,garbage_collection_threshold:0.8,expandable_segments:True
|
| 19 |
-
|
| 20 |
-
# Other memory-related settings
|
| 21 |
-
CUDA_LAUNCH_BLOCKING=0
|
| 22 |
-
CUDA_VISIBLE_DEVICES=0
|
| 23 |
-
|
| 24 |
-
# Logging Configuration
|
| 25 |
-
LOG_LEVEL=INFO # DEBUG, INFO, WARNING, ERROR, CRITICAL
|
| 26 |
-
|
| 27 |
-
# GPU Configuration (optional)
|
| 28 |
-
# CUDA_VISIBLE_DEVICES=0,1 # Specify which GPUs to use
|
| 29 |
-
|
| 30 |
-
# Memory Configuration (optional)
|
| 31 |
-
# MAX_GPU_MEMORY=16Gi # Maximum GPU memory to use
|
| 32 |
-
# MAX_CPU_MEMORY=32Gi # Maximum CPU memory to use
|
| 33 |
-
|
| 34 |
-
# Security (if needed)
|
| 35 |
-
# API_KEY=your-api-key-here
|
| 36 |
-
# SSL_CERT_PATH=/path/to/cert
|
| 37 |
-
# SSL_KEY_PATH=/path/to/key
|
| 38 |
-
|
| 39 |
-
# Development Settings
|
| 40 |
-
# DEBUG=True # Enable debug mode
|
| 41 |
-
# RELOAD=False # Enable auto-reload for development
|
| 42 |
-
|
| 43 |
-
# Model Default Parameters (optional)
|
| 44 |
-
# DEFAULT_MAX_NEW_TOKENS=50
|
| 45 |
-
# DEFAULT_TEMPERATURE=1.0
|
| 46 |
-
# DEFAULT_TOP_K=50
|
| 47 |
-
# DEFAULT_TOP_P=1.0
|
| 48 |
-
|
| 49 |
-
# Cache Settings (optional)
|
| 50 |
-
# CACHE_DIR=/path/to/cache
|
| 51 |
-
# MAX_CACHE_SIZE=10Gi
|
| 52 |
-
|
| 53 |
-
# Monitoring (optional)
|
| 54 |
-
# ENABLE_METRICS=True
|
| 55 |
-
# PROMETHEUS_PORT=9090
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/main.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
-
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
-
import logging
|
| 4 |
-
import os
|
| 5 |
-
import uvicorn
|
| 6 |
-
from .routes import router
|
| 7 |
-
|
| 8 |
-
# Set up logging
|
| 9 |
-
logging.basicConfig(level=logging.INFO)
|
| 10 |
-
logger = logging.getLogger(__name__)
|
| 11 |
-
|
| 12 |
-
# Initialize FastAPI with simplified configuration
|
| 13 |
-
app = FastAPI(
|
| 14 |
-
title="LLM Engine Service",
|
| 15 |
-
docs_url="/docs",
|
| 16 |
-
redoc_url="/redoc",
|
| 17 |
-
openapi_url="/openapi.json"
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
# Add CORS middleware
|
| 21 |
-
app.add_middleware(
|
| 22 |
-
CORSMiddleware,
|
| 23 |
-
allow_origins=["*"],
|
| 24 |
-
allow_credentials=True,
|
| 25 |
-
allow_methods=["*"],
|
| 26 |
-
allow_headers=["*"],
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
# Include the router from routes.py
|
| 30 |
-
app.include_router(router)
|
| 31 |
-
|
| 32 |
-
def main():
|
| 33 |
-
# Load environment variables or configuration here
|
| 34 |
-
host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
|
| 35 |
-
port = int(os.getenv("LLM_ENGINE_PORT", "7860")) # Default to 7860 for Spaces
|
| 36 |
-
|
| 37 |
-
# Log startup information
|
| 38 |
-
logger.info(f"Starting LLM Engine service on {host}:{port}, or: ")
|
| 39 |
-
logger.info("Available endpoints:")
|
| 40 |
-
logger.info(" - /")
|
| 41 |
-
logger.info(" - /health")
|
| 42 |
-
logger.info(" - /models")
|
| 43 |
-
logger.info(" - /initialize")
|
| 44 |
-
logger.info(" - /generate")
|
| 45 |
-
logger.info(" - /generate/stream")
|
| 46 |
-
logger.info(" - /download")
|
| 47 |
-
logger.info(" - /convert")
|
| 48 |
-
logger.info(" - /docs")
|
| 49 |
-
logger.info(" - /redoc")
|
| 50 |
-
logger.info(" - /openapi.json")
|
| 51 |
-
|
| 52 |
-
# Start the server
|
| 53 |
-
uvicorn.run(
|
| 54 |
-
app,
|
| 55 |
-
host=host,
|
| 56 |
-
port=port,
|
| 57 |
-
log_level="info"
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
if __name__ == "__main__":
|
| 61 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main/routes.py
DELETED
|
@@ -1,419 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
from fastapi import APIRouter, HTTPException
|
| 3 |
-
from fastapi.responses import StreamingResponse
|
| 4 |
-
from pydantic import BaseModel, Field
|
| 5 |
-
from typing import Optional, Union, AsyncGenerator, List
|
| 6 |
-
import torch
|
| 7 |
-
import logging
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from litgpt.api import LLM
|
| 10 |
-
from litgpt.scripts.download import download_from_hub
|
| 11 |
-
from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint
|
| 12 |
-
import json
|
| 13 |
-
import asyncio
|
| 14 |
-
|
| 15 |
-
# Set up logging
|
| 16 |
-
logger = logging.getLogger(__name__)
|
| 17 |
-
|
| 18 |
-
# Create router instance
|
| 19 |
-
router = APIRouter()
|
| 20 |
-
|
| 21 |
-
# Global variable to store the LLM instance
|
| 22 |
-
llm_instance = None
|
| 23 |
-
|
| 24 |
-
class InitializeRequest(BaseModel):
|
| 25 |
-
"""Configuration for model initialization including model path"""
|
| 26 |
-
mode: str = Field(default="cpu", description="Execution mode ('cpu' or 'gpu')")
|
| 27 |
-
precision: Optional[str] = Field(None, description="Precision format (e.g., 'bf16-true', 'bf16-mixed')")
|
| 28 |
-
quantize: Optional[str] = Field(None, description="Quantization format (e.g., 'bnb.nf4')")
|
| 29 |
-
gpu_count: Union[str, int] = Field(default="auto", description="Number of GPUs to use or 'auto'")
|
| 30 |
-
model_path: str = Field(..., description="Path to the model relative to checkpoints directory")
|
| 31 |
-
|
| 32 |
-
class GenerateRequest(BaseModel):
|
| 33 |
-
"""Request parameters for text generation"""
|
| 34 |
-
prompt: str = Field(..., description="Input text prompt for generation")
|
| 35 |
-
max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
|
| 36 |
-
temperature: float = Field(default=1.0, description="Sampling temperature")
|
| 37 |
-
top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
|
| 38 |
-
top_p: float = Field(default=1.0, description="Top-p sampling parameter")
|
| 39 |
-
return_as_token_ids: bool = Field(default=False, description="Whether to return token IDs instead of text")
|
| 40 |
-
stream: bool = Field(default=False, description="Whether to stream the response")
|
| 41 |
-
|
| 42 |
-
class StreamGenerateRequest(BaseModel):
|
| 43 |
-
"""Request parameters for streaming text generation"""
|
| 44 |
-
prompt: str = Field(..., description="Input text prompt for generation")
|
| 45 |
-
max_new_tokens: int = Field(default=50, description="Maximum number of tokens to generate")
|
| 46 |
-
temperature: float = Field(default=1.0, description="Sampling temperature")
|
| 47 |
-
top_k: Optional[int] = Field(None, description="Top-k sampling parameter")
|
| 48 |
-
top_p: float = Field(default=1.0, description="Top-p sampling parameter")
|
| 49 |
-
|
| 50 |
-
class DownloadModelRequest(BaseModel):
|
| 51 |
-
"""Request to download a model from HuggingFace"""
|
| 52 |
-
repo_id: str = Field(
|
| 53 |
-
...,
|
| 54 |
-
description="HuggingFace repository ID (e.g., 'huihui-ai/Llama-3.2-3B-Instruct-abliterated')"
|
| 55 |
-
)
|
| 56 |
-
model_name: str = Field(
|
| 57 |
-
...,
|
| 58 |
-
description="Model architecture name (e.g., 'Llama-3.2-3B-Instruct')"
|
| 59 |
-
)
|
| 60 |
-
access_token: Optional[str] = Field(
|
| 61 |
-
None,
|
| 62 |
-
description="HuggingFace access token for private models"
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
class ConvertModelRequest(BaseModel):
|
| 66 |
-
"""Request to convert a downloaded model"""
|
| 67 |
-
folder_path: str = Field(
|
| 68 |
-
...,
|
| 69 |
-
description="Path relative to checkpoints where model was downloaded"
|
| 70 |
-
)
|
| 71 |
-
model_name: str = Field(
|
| 72 |
-
...,
|
| 73 |
-
description="Model architecture name for conversion"
|
| 74 |
-
)
|
| 75 |
-
|
| 76 |
-
class ModelResponse(BaseModel):
|
| 77 |
-
"""Model information response"""
|
| 78 |
-
name: str = Field(..., description="Full model name including organization")
|
| 79 |
-
path: str = Field(..., description="Relative path in checkpoints directory")
|
| 80 |
-
downloaded: bool = Field(..., description="Whether the model files are downloaded")
|
| 81 |
-
converted: bool = Field(..., description="Whether the model is converted to LitGPT format")
|
| 82 |
-
has_safetensors: bool = Field(..., description="Whether safetensors files are present")
|
| 83 |
-
files: List[str] = Field(..., description="List of files in model directory")
|
| 84 |
-
|
| 85 |
-
class ModelsListResponse(BaseModel):
|
| 86 |
-
"""Response for listing models"""
|
| 87 |
-
models: List[ModelResponse] = Field(..., description="List of available models")
|
| 88 |
-
|
| 89 |
-
@router.post(
|
| 90 |
-
"/download",
|
| 91 |
-
response_model=dict,
|
| 92 |
-
summary="Download a model from HuggingFace Hub",
|
| 93 |
-
description="Downloads a model from HuggingFace to the LLM Engine's checkpoints directory",
|
| 94 |
-
response_description="Download status and location information"
|
| 95 |
-
)
|
| 96 |
-
async def download_model(request: DownloadModelRequest):
|
| 97 |
-
"""
|
| 98 |
-
Download a model from HuggingFace Hub.
|
| 99 |
-
|
| 100 |
-
- Downloads model files to the checkpoints directory
|
| 101 |
-
- Creates necessary subdirectories
|
| 102 |
-
- Handles authentication for private models
|
| 103 |
-
|
| 104 |
-
Returns:
|
| 105 |
-
A JSON object containing download status and path information
|
| 106 |
-
"""
|
| 107 |
-
try:
|
| 108 |
-
# Get the project root directory and construct paths
|
| 109 |
-
project_root = Path(__file__).parent.parent
|
| 110 |
-
checkpoints_dir = project_root / "checkpoints"
|
| 111 |
-
logger.info(f"Downloading model {request.repo_id} to {checkpoints_dir}")
|
| 112 |
-
|
| 113 |
-
download_from_hub(
|
| 114 |
-
repo_id=request.repo_id,
|
| 115 |
-
model_name=request.model_name,
|
| 116 |
-
access_token=request.access_token,
|
| 117 |
-
checkpoint_dir=checkpoints_dir,
|
| 118 |
-
tokenizer_only=False
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
return {
|
| 122 |
-
"status": "success",
|
| 123 |
-
"message": f"Model downloaded to {checkpoints_dir / request.repo_id}",
|
| 124 |
-
"path": str(request.repo_id)
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
except Exception as e:
|
| 128 |
-
logger.error(f"Error downloading model: {str(e)}")
|
| 129 |
-
raise HTTPException(status_code=500, detail=f"Error downloading model: {str(e)}")
|
| 130 |
-
|
| 131 |
-
@router.post(
|
| 132 |
-
"/convert",
|
| 133 |
-
response_model=dict,
|
| 134 |
-
summary="Convert a model to LitGPT format",
|
| 135 |
-
description="Converts a downloaded model to the LitGPT format required for inference",
|
| 136 |
-
response_description="Conversion status and location information"
|
| 137 |
-
)
|
| 138 |
-
async def convert_model(request: ConvertModelRequest):
|
| 139 |
-
"""
|
| 140 |
-
Convert a downloaded model to LitGPT format.
|
| 141 |
-
|
| 142 |
-
- Converts model files to LitGPT's format
|
| 143 |
-
- Creates lit_model.pth file
|
| 144 |
-
- Maintains original files
|
| 145 |
-
|
| 146 |
-
Returns:
|
| 147 |
-
A JSON object containing conversion status and path information
|
| 148 |
-
"""
|
| 149 |
-
try:
|
| 150 |
-
project_root = Path(__file__).parent.parent
|
| 151 |
-
checkpoints_dir = project_root / "checkpoints"
|
| 152 |
-
model_dir = checkpoints_dir / request.folder_path
|
| 153 |
-
|
| 154 |
-
if not model_dir.exists():
|
| 155 |
-
raise HTTPException(
|
| 156 |
-
status_code=404,
|
| 157 |
-
detail=f"Model directory not found: {request.folder_path}"
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
logger.info(f"Converting model in {model_dir}")
|
| 161 |
-
convert_hf_checkpoint(
|
| 162 |
-
checkpoint_dir=model_dir,
|
| 163 |
-
model_name=request.model_name
|
| 164 |
-
)
|
| 165 |
-
|
| 166 |
-
return {
|
| 167 |
-
"status": "success",
|
| 168 |
-
"message": f"Model converted successfully",
|
| 169 |
-
"path": str(request.folder_path)
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
except Exception as e:
|
| 173 |
-
logger.error(f"Error converting model: {str(e)}")
|
| 174 |
-
raise HTTPException(status_code=500, detail=f"Error converting model: {str(e)}")
|
| 175 |
-
|
| 176 |
-
@router.get(
|
| 177 |
-
"/models",
|
| 178 |
-
response_model=ModelsListResponse,
|
| 179 |
-
summary="List available models",
|
| 180 |
-
description="Lists all models in the checkpoints directory with their status",
|
| 181 |
-
response_description="List of models with their details and status"
|
| 182 |
-
)
|
| 183 |
-
async def list_models():
|
| 184 |
-
"""
|
| 185 |
-
List all models in the checkpoints directory.
|
| 186 |
-
|
| 187 |
-
Returns:
|
| 188 |
-
A JSON object containing:
|
| 189 |
-
- List of models
|
| 190 |
-
- Each model's download status
|
| 191 |
-
- Each model's conversion status
|
| 192 |
-
- Available files for each model
|
| 193 |
-
"""
|
| 194 |
-
try:
|
| 195 |
-
project_root = Path(__file__).parent.parent
|
| 196 |
-
checkpoints_dir = project_root / "checkpoints"
|
| 197 |
-
models = []
|
| 198 |
-
|
| 199 |
-
if checkpoints_dir.exists():
|
| 200 |
-
for org_dir in checkpoints_dir.iterdir():
|
| 201 |
-
if org_dir.is_dir():
|
| 202 |
-
for model_dir in org_dir.iterdir():
|
| 203 |
-
if model_dir.is_dir():
|
| 204 |
-
files = [f.name for f in model_dir.iterdir()]
|
| 205 |
-
has_safetensors = any(f.endswith('.safetensors') for f in files)
|
| 206 |
-
has_lit_model = 'lit_model.pth' in files
|
| 207 |
-
|
| 208 |
-
model_info = ModelResponse(
|
| 209 |
-
name=f"{org_dir.name}/{model_dir.name}",
|
| 210 |
-
path=str(model_dir.relative_to(checkpoints_dir)),
|
| 211 |
-
downloaded=True,
|
| 212 |
-
converted=has_lit_model,
|
| 213 |
-
has_safetensors=has_safetensors,
|
| 214 |
-
files=files
|
| 215 |
-
)
|
| 216 |
-
models.append(model_info)
|
| 217 |
-
|
| 218 |
-
return ModelsListResponse(models=models)
|
| 219 |
-
|
| 220 |
-
except Exception as e:
|
| 221 |
-
logger.error(f"Error listing models: {str(e)}")
|
| 222 |
-
raise HTTPException(status_code=500, detail=f"Error listing models: {str(e)}")
|
| 223 |
-
|
| 224 |
-
@router.post("/initialize")
|
| 225 |
-
async def initialize_model(request: InitializeRequest):
|
| 226 |
-
"""
|
| 227 |
-
Initialize the LLM model with specified configuration.
|
| 228 |
-
"""
|
| 229 |
-
global llm_instance
|
| 230 |
-
|
| 231 |
-
try:
|
| 232 |
-
# Get the project root directory (where main.py is located)
|
| 233 |
-
project_root = Path(__file__).parent.parent
|
| 234 |
-
checkpoints_dir = project_root / "checkpoints"
|
| 235 |
-
logger.info(f"Checkpoint dir is: {checkpoints_dir}")
|
| 236 |
-
|
| 237 |
-
# For LitGPT downloaded models, path includes organization
|
| 238 |
-
if "/" in request.model_path:
|
| 239 |
-
# e.g., "mistralai/Mistral-7B-Instruct-v0.3"
|
| 240 |
-
org, model_name = request.model_path.split("/")
|
| 241 |
-
model_path = str(checkpoints_dir / org / model_name)
|
| 242 |
-
else:
|
| 243 |
-
# Fallback for direct model paths
|
| 244 |
-
model_path = str(checkpoints_dir / request.model_path)
|
| 245 |
-
|
| 246 |
-
logger.info(f"Using model path: {model_path}")
|
| 247 |
-
|
| 248 |
-
# Load the model
|
| 249 |
-
llm_instance = LLM.load(
|
| 250 |
-
model=model_path,
|
| 251 |
-
distribute=None if request.precision or request.quantize else "auto"
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
# If manual distribution is needed
|
| 255 |
-
logger.info("Distributing model")
|
| 256 |
-
if request.precision or request.quantize:
|
| 257 |
-
llm_instance.distribute(
|
| 258 |
-
accelerator="cuda" if request.mode == "gpu" else "cpu",
|
| 259 |
-
devices=request.gpu_count,
|
| 260 |
-
precision=request.precision,
|
| 261 |
-
quantize=request.quantize
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
-
logger.info(
|
| 265 |
-
f"Model initialized successfully with config:\n"
|
| 266 |
-
f"Mode: {request.mode}\n"
|
| 267 |
-
f"Precision: {request.precision}\n"
|
| 268 |
-
f"Quantize: {request.quantize}\n"
|
| 269 |
-
f"GPU Count: {request.gpu_count}\n"
|
| 270 |
-
f"Model Path: {model_path}\n"
|
| 271 |
-
f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
|
| 272 |
-
f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
return {"success": True, "message": "Model initialized successfully"}
|
| 276 |
-
|
| 277 |
-
except Exception as e:
|
| 278 |
-
logger.error(f"Error initializing model: {str(e)}")
|
| 279 |
-
# Print detailed memory statistics on failure
|
| 280 |
-
logger.error(f"GPU Memory Stats:\n"
|
| 281 |
-
f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
|
| 282 |
-
f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
|
| 283 |
-
f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
|
| 284 |
-
raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
|
| 285 |
-
|
| 286 |
-
@router.post("/generate")
|
| 287 |
-
async def generate(request: GenerateRequest):
|
| 288 |
-
"""
|
| 289 |
-
Generate text using the initialized model.
|
| 290 |
-
"""
|
| 291 |
-
global llm_instance
|
| 292 |
-
|
| 293 |
-
if llm_instance is None:
|
| 294 |
-
raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
|
| 295 |
-
|
| 296 |
-
try:
|
| 297 |
-
if request.stream:
|
| 298 |
-
raise HTTPException(
|
| 299 |
-
status_code=400,
|
| 300 |
-
detail="Streaming is not currently supported through the API"
|
| 301 |
-
)
|
| 302 |
-
|
| 303 |
-
generated_text = llm_instance.generate(
|
| 304 |
-
prompt=request.prompt,
|
| 305 |
-
max_new_tokens=request.max_new_tokens,
|
| 306 |
-
temperature=request.temperature,
|
| 307 |
-
top_k=request.top_k,
|
| 308 |
-
top_p=request.top_p,
|
| 309 |
-
return_as_token_ids=request.return_as_token_ids,
|
| 310 |
-
stream=False # Force stream to False for now
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
response = {
|
| 314 |
-
"generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
|
| 315 |
-
"metadata": {
|
| 316 |
-
"prompt": request.prompt,
|
| 317 |
-
"max_new_tokens": request.max_new_tokens,
|
| 318 |
-
"temperature": request.temperature,
|
| 319 |
-
"top_k": request.top_k,
|
| 320 |
-
"top_p": request.top_p
|
| 321 |
-
}
|
| 322 |
-
}
|
| 323 |
-
|
| 324 |
-
return response
|
| 325 |
-
|
| 326 |
-
except Exception as e:
|
| 327 |
-
logger.error(f"Error generating text: {str(e)}")
|
| 328 |
-
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
|
| 329 |
-
|
| 330 |
-
@router.post("/generate/stream")
|
| 331 |
-
async def generate_stream(request: StreamGenerateRequest):
|
| 332 |
-
"""
|
| 333 |
-
Generate text using the initialized model with streaming response.
|
| 334 |
-
Returns a StreamingResponse that yields JSON-formatted chunks of text.
|
| 335 |
-
"""
|
| 336 |
-
global llm_instance
|
| 337 |
-
|
| 338 |
-
if llm_instance is None:
|
| 339 |
-
raise HTTPException(
|
| 340 |
-
status_code=400,
|
| 341 |
-
detail="Model not initialized. Call /initialize first."
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
-
async def event_generator() -> AsyncGenerator[str, None]:
|
| 345 |
-
try:
|
| 346 |
-
# Start the generation with streaming enabled
|
| 347 |
-
for token in llm_instance.generate(
|
| 348 |
-
prompt=request.prompt,
|
| 349 |
-
max_new_tokens=request.max_new_tokens,
|
| 350 |
-
temperature=request.temperature,
|
| 351 |
-
top_k=request.top_k,
|
| 352 |
-
top_p=request.top_p,
|
| 353 |
-
stream=True # Enable streaming
|
| 354 |
-
):
|
| 355 |
-
# Create a JSON response for each token
|
| 356 |
-
chunk = {
|
| 357 |
-
"token": token,
|
| 358 |
-
"metadata": {
|
| 359 |
-
"prompt": request.prompt,
|
| 360 |
-
"is_finished": False
|
| 361 |
-
}
|
| 362 |
-
}
|
| 363 |
-
# Format as SSE data
|
| 364 |
-
yield f"data: {json.dumps(chunk)}\n\n"
|
| 365 |
-
|
| 366 |
-
# Small delay to prevent overwhelming the client
|
| 367 |
-
await asyncio.sleep(0.01)
|
| 368 |
-
|
| 369 |
-
# Send final message indicating completion
|
| 370 |
-
final_chunk = {
|
| 371 |
-
"token": "",
|
| 372 |
-
"metadata": {
|
| 373 |
-
"prompt": request.prompt,
|
| 374 |
-
"is_finished": True
|
| 375 |
-
}
|
| 376 |
-
}
|
| 377 |
-
yield f"data: {json.dumps(final_chunk)}\n\n"
|
| 378 |
-
|
| 379 |
-
except Exception as e:
|
| 380 |
-
logger.error(f"Error in stream generation: {str(e)}")
|
| 381 |
-
error_chunk = {
|
| 382 |
-
"error": str(e),
|
| 383 |
-
"metadata": {
|
| 384 |
-
"prompt": request.prompt,
|
| 385 |
-
"is_finished": True
|
| 386 |
-
}
|
| 387 |
-
}
|
| 388 |
-
yield f"data: {json.dumps(error_chunk)}\n\n"
|
| 389 |
-
|
| 390 |
-
return StreamingResponse(
|
| 391 |
-
event_generator(),
|
| 392 |
-
media_type="text/event-stream",
|
| 393 |
-
headers={
|
| 394 |
-
'Cache-Control': 'no-cache',
|
| 395 |
-
'Connection': 'keep-alive',
|
| 396 |
-
}
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
@router.get("/health")
|
| 400 |
-
async def health_check():
|
| 401 |
-
"""
|
| 402 |
-
Check if the service is running and model is loaded.
|
| 403 |
-
Returns status information including model details if loaded.
|
| 404 |
-
"""
|
| 405 |
-
global llm_instance
|
| 406 |
-
|
| 407 |
-
status = {
|
| 408 |
-
"status": "healthy",
|
| 409 |
-
"model_loaded": llm_instance is not None,
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
if llm_instance is not None:
|
| 413 |
-
logger.info(f"llm_instance is: {llm_instance}")
|
| 414 |
-
status["model_info"] = {
|
| 415 |
-
"model_path": llm_instance.config.name,
|
| 416 |
-
"device": str(next(llm_instance.model.parameters()).device)
|
| 417 |
-
}
|
| 418 |
-
|
| 419 |
-
return status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,67 +1,57 @@
|
|
| 1 |
-
|
| 2 |
-
aiohttp==3.10.10
|
| 3 |
-
aiosignal==1.3.1
|
| 4 |
annotated-types==0.7.0
|
| 5 |
anyio==4.6.2.post1
|
| 6 |
-
attrs==24.2.0
|
| 7 |
bitsandbytes==0.44.1
|
| 8 |
certifi==2024.8.30
|
| 9 |
charset-normalizer==3.4.0
|
| 10 |
click==8.1.7
|
| 11 |
-
|
| 12 |
-
fastapi==0.109.0
|
| 13 |
filelock==3.16.1
|
| 14 |
-
frozenlist==1.5.0
|
| 15 |
fsspec==2024.10.0
|
| 16 |
h11==0.14.0
|
| 17 |
-
huggingface-hub==0.
|
| 18 |
idna==3.10
|
| 19 |
-
|
| 20 |
Jinja2==3.1.4
|
| 21 |
-
jsonargparse==4.32.1
|
| 22 |
-
lightning==2.4.0
|
| 23 |
-
lightning-utilities==0.11.8
|
| 24 |
-
litgpt==0.5.3
|
| 25 |
MarkupSafe==3.0.2
|
| 26 |
mpmath==1.3.0
|
| 27 |
-
multidict==6.1.0
|
| 28 |
networkx==3.4.2
|
| 29 |
-
numpy==1.
|
| 30 |
-
nvidia-cublas-cu12==12.
|
| 31 |
-
nvidia-cuda-cupti-cu12==12.
|
| 32 |
-
nvidia-cuda-nvrtc-cu12==12.
|
| 33 |
-
nvidia-cuda-runtime-cu12==12.
|
| 34 |
nvidia-cudnn-cu12==9.1.0.70
|
| 35 |
-
nvidia-cufft-cu12==11.
|
| 36 |
-
nvidia-curand-cu12==10.3.
|
| 37 |
-
nvidia-cusolver-cu12==11.
|
| 38 |
-
nvidia-cusparse-cu12==12.1.
|
| 39 |
-
nvidia-nccl-cu12==2.
|
| 40 |
-
nvidia-nvjitlink-cu12==12.
|
| 41 |
-
nvidia-nvtx-cu12==12.
|
| 42 |
-
packaging==24.
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
| 48 |
PyYAML==6.0.2
|
| 49 |
-
regex==2024.
|
| 50 |
requests==2.32.3
|
|
|
|
| 51 |
safetensors==0.4.5
|
| 52 |
-
setuptools==75.
|
| 53 |
sniffio==1.3.1
|
| 54 |
-
starlette==0.
|
| 55 |
-
sympy==1.13.
|
| 56 |
tokenizers==0.20.3
|
| 57 |
-
torch==2.
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
triton==3.0.0
|
| 62 |
-
typeshed_client==2.7.0
|
| 63 |
typing_extensions==4.12.2
|
| 64 |
urllib3==2.2.3
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 1 |
+
accelerate==1.1.1
|
|
|
|
|
|
|
| 2 |
annotated-types==0.7.0
|
| 3 |
anyio==4.6.2.post1
|
|
|
|
| 4 |
bitsandbytes==0.44.1
|
| 5 |
certifi==2024.8.30
|
| 6 |
charset-normalizer==3.4.0
|
| 7 |
click==8.1.7
|
| 8 |
+
fastapi==0.115.5
|
|
|
|
| 9 |
filelock==3.16.1
|
|
|
|
| 10 |
fsspec==2024.10.0
|
| 11 |
h11==0.14.0
|
| 12 |
+
huggingface-hub==0.26.3
|
| 13 |
idna==3.10
|
| 14 |
+
inquirerpy==0.3.4
|
| 15 |
Jinja2==3.1.4
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
MarkupSafe==3.0.2
|
| 17 |
mpmath==1.3.0
|
|
|
|
| 18 |
networkx==3.4.2
|
| 19 |
+
numpy==2.1.3
|
| 20 |
+
nvidia-cublas-cu12==12.4.5.8
|
| 21 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
| 22 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
| 23 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
| 24 |
nvidia-cudnn-cu12==9.1.0.70
|
| 25 |
+
nvidia-cufft-cu12==11.2.1.3
|
| 26 |
+
nvidia-curand-cu12==10.3.5.147
|
| 27 |
+
nvidia-cusolver-cu12==11.6.1.9
|
| 28 |
+
nvidia-cusparse-cu12==12.3.1.170
|
| 29 |
+
nvidia-nccl-cu12==2.21.5
|
| 30 |
+
nvidia-nvjitlink-cu12==12.4.127
|
| 31 |
+
nvidia-nvtx-cu12==12.4.127
|
| 32 |
+
packaging==24.2
|
| 33 |
+
pfzy==0.3.4
|
| 34 |
+
prompt_toolkit==3.0.48
|
| 35 |
+
psutil==6.1.0
|
| 36 |
+
pydantic==2.10.2
|
| 37 |
+
pydantic_core==2.27.1
|
| 38 |
+
python-dotenv==1.0.1
|
| 39 |
PyYAML==6.0.2
|
| 40 |
+
regex==2024.11.6
|
| 41 |
requests==2.32.3
|
| 42 |
+
router==0.1
|
| 43 |
safetensors==0.4.5
|
| 44 |
+
setuptools==75.6.0
|
| 45 |
sniffio==1.3.1
|
| 46 |
+
starlette==0.41.3
|
| 47 |
+
sympy==1.13.1
|
| 48 |
tokenizers==0.20.3
|
| 49 |
+
torch==2.5.1
|
| 50 |
+
tqdm==4.67.1
|
| 51 |
+
transformers==4.46.3
|
| 52 |
+
triton==3.1.0
|
|
|
|
|
|
|
| 53 |
typing_extensions==4.12.2
|
| 54 |
urllib3==2.2.3
|
| 55 |
+
utils==1.0.2
|
| 56 |
+
uvicorn==0.32.1
|
| 57 |
+
wcwidth==0.2.13
|