# aneeb15's picture
# Initial release of Auto-FineTune-Ops
# d4398e6
"""
FastAPI Deployment Server
==========================
One-click deployment bridge for fine-tuned models.
"""
import os
from pathlib import Path
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from datetime import datetime
from rich.console import Console
# Shared Rich console used for the status messages printed by this module.
console = Console()
@dataclass
class GenerationRequest:
    """Parameters describing a single text-generation request.

    Only ``prompt`` is required; every other field has a default.
    """

    # User instruction to complete.
    prompt: str
    # Optional text placed ahead of the instruction.
    system_prompt: Optional[str] = None
    # Upper bound on the number of newly generated tokens.
    max_tokens: int = 512
    # Sampling temperature.
    temperature: float = 0.7
    # Top-p sampling parameter.
    top_p: float = 0.9
    # Streaming flag; off by default.
    stream: bool = False
@dataclass
class GenerationResponse:
    """Result of a single text-generation call.

    All fields are required and are supplied by the code that ran the
    generation.
    """

    # Text produced by the model.
    generated_text: str
    # Prompt the text was generated for.
    prompt: str
    # Identifier of the model that produced the text.
    model: str
    # Number of tokens reported for the generated text.
    tokens_generated: int
    # Time spent generating, in seconds.
    generation_time: float
class DeploymentServer:
    """
    FastAPI-based deployment server for fine-tuned models.

    Features:
    - RESTful API for inference
    - Health check endpoint
    - Batch generation support
    - Automatic model loading
    """

    # Instruction header used when the caller supplies no system prompt.
    _DEFAULT_SYSTEM_PROMPT = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )

    def __init__(
        self,
        model_path: str,
        host: str = "0.0.0.0",
        port: int = 8000,
        max_seq_length: int = 2048
    ):
        """
        Initialize the deployment server.

        Nothing is loaded here; the model, tokenizer and FastAPI app are
        created later by load_model() / create_app() (or by run()).

        Args:
            model_path: Path to the fine-tuned model
            host: Server host
            port: Server port
            max_seq_length: Maximum sequence length
        """
        self.model_path = model_path
        self.host = host
        self.port = port
        self.max_seq_length = max_seq_length
        self.model = None
        self.tokenizer = None
        self.app = None

    def load_model(self):
        """
        Load the fine-tuned model and tokenizer.

        Unsloth is tried first (4-bit loading, then switched to inference
        mode). If an ImportError is raised on that path, the model is loaded
        with plain transformers instead.
        """
        console.print(f"\n[bold blue]📂 Loading model from:[/] {self.model_path}")
        try:
            from unsloth import FastLanguageModel

            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=self.model_path,
                max_seq_length=self.max_seq_length,
                dtype=None,
                load_in_4bit=True,
            )
            FastLanguageModel.for_inference(self.model)
            console.print("[green]✓ Model loaded successfully[/]")
        except ImportError:
            console.print("[yellow]⚠️ Unsloth not available, trying transformers...[/]")
            from transformers import AutoModelForCausalLM, AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                device_map="auto",
                torch_dtype="auto"
            )
            console.print("[green]✓ Model loaded with transformers[/]")

    @staticmethod
    def _format_prompt(prompt: str, system_prompt: Optional[str] = None) -> str:
        """Wrap *prompt* in the Alpaca-style instruction template."""
        # An empty or missing system prompt falls back to the default header.
        header = system_prompt or DeploymentServer._DEFAULT_SYSTEM_PROMPT
        # NOTE(review): this layout should match the template used during
        # fine-tuning - confirm against the training code.
        return f"{header}\n### Instruction:\n{prompt}\n### Response:\n"

    @staticmethod
    def _sampling_kwargs(temperature: float, top_p: float) -> dict:
        """Translate sampling parameters into keyword arguments for generate().

        A temperature of 0 selects greedy decoding instead of failing inside
        the model's generate call; negative temperatures are rejected.
        """
        if temperature < 0:
            raise ValueError("temperature must be non-negative")
        if temperature == 0:
            # Sampling-only options are omitted when sampling is disabled.
            return {"do_sample": False}
        return {"do_sample": True, "temperature": temperature, "top_p": top_p}

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9
    ) -> "GenerationResponse":
        """
        Generate text from the model.

        Args:
            prompt: User prompt
            system_prompt: Optional system prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature; 0 selects greedy decoding
            top_p: Top-p sampling parameter (ignored for greedy decoding)

        Returns:
            GenerationResponse with generated text

        Raises:
            RuntimeError: If the model has not been loaded.
            ValueError: If temperature is negative.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Validate the sampling parameters before doing any model work.
        sampling = self._sampling_kwargs(temperature, top_p)

        start_time = datetime.now()
        formatted_prompt = self._format_prompt(prompt, system_prompt)

        # Tokenize
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt"
        ).to(self.model.device)
        prompt_length = inputs["input_ids"].shape[1]

        # Generate
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            **sampling
        )

        # The returned sequence starts with the prompt tokens. Keep only what
        # follows them, so the completion is not recovered by string matching
        # on "### Response:" (which breaks when the model emits that marker).
        new_tokens = outputs[0][prompt_length:]
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        generation_time = (datetime.now() - start_time).total_seconds()

        return GenerationResponse(
            generated_text=generated_text,
            prompt=prompt,
            model=self.model_path,
            # Counted from the produced tokens, not by re-encoding the text.
            tokens_generated=len(new_tokens),
            generation_time=generation_time
        )

    def create_app(self):
        """
        Create the FastAPI application.

        Registers the ``/``, ``/health``, ``/generate`` and
        ``/generate/batch`` routes, stores the app on ``self.app`` and
        returns it.
        """
        from fastapi import FastAPI, HTTPException
        from fastapi.middleware.cors import CORSMiddleware
        from pydantic import BaseModel

        app = FastAPI(
            title="Auto-FineTune-Ops Inference API",
            description="API for serving fine-tuned LLM models",
            version="1.0.0"
        )

        # CORS middleware. Credentials are disabled because they must not be
        # combined with wildcard origins.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=False,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # Pydantic models for API
        class GenerateRequest(BaseModel):
            prompt: str
            system_prompt: Optional[str] = None
            max_tokens: int = 512
            temperature: float = 0.7
            top_p: float = 0.9

        class GenerateResponse(BaseModel):
            generated_text: str
            prompt: str
            model: str
            tokens_generated: int
            generation_time: float

        class BatchGenerateRequest(BaseModel):
            prompts: List[str]
            system_prompt: Optional[str] = None
            max_tokens: int = 512
            temperature: float = 0.7
            top_p: float = 0.9

        class HealthResponse(BaseModel):
            status: str
            model: str
            model_loaded: bool

        def to_api_response(result) -> GenerateResponse:
            """Convert a generate() result into the API response model."""
            return GenerateResponse(
                generated_text=result.generated_text,
                prompt=result.prompt,
                model=result.model,
                tokens_generated=result.tokens_generated,
                generation_time=result.generation_time
            )

        # NOTE(review): the handlers below are coroutines but call the
        # blocking self.generate() directly, so other requests (including
        # /health) wait while a generation is running.

        @app.get("/health", response_model=HealthResponse)
        async def health_check():
            """Health check endpoint."""
            return HealthResponse(
                status="healthy",
                model=self.model_path,
                model_loaded=self.model is not None
            )

        @app.post("/generate", response_model=GenerateResponse)
        async def generate_text(request: GenerateRequest):
            """Generate text from a single prompt."""
            if self.model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")
            try:
                result = self.generate(
                    prompt=request.prompt,
                    system_prompt=request.system_prompt,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p
                )
                return to_api_response(result)
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

        @app.post("/generate/batch", response_model=List[GenerateResponse])
        async def batch_generate(request: BatchGenerateRequest):
            """Generate text from multiple prompts."""
            if self.model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")
            results = []
            for prompt in request.prompts:
                try:
                    result = self.generate(
                        prompt=prompt,
                        system_prompt=request.system_prompt,
                        max_tokens=request.max_tokens,
                        temperature=request.temperature,
                        top_p=request.top_p
                    )
                    results.append(to_api_response(result))
                except Exception as e:
                    # Best effort: one failing prompt is reported in place
                    # and does not abort the rest of the batch.
                    results.append(GenerateResponse(
                        generated_text=f"Error: {str(e)}",
                        prompt=prompt,
                        model=self.model_path,
                        tokens_generated=0,
                        generation_time=0.0
                    ))
            return results

        @app.get("/")
        async def root():
            """Root endpoint with API info."""
            return {
                "name": "Auto-FineTune-Ops Inference API",
                "version": "1.0.0",
                "model": self.model_path,
                "endpoints": {
                    "/health": "Health check",
                    "/generate": "Generate text (POST)",
                    "/generate/batch": "Batch generation (POST)"
                }
            }

        self.app = app
        return app

    def run(self, reload: bool = False):
        """
        Start the FastAPI server.

        Loads the model and builds the app first if that has not been done.

        Args:
            reload: Request auto-reload for development. The app is handed
                to uvicorn as an object, which uvicorn cannot reload, so the
                request is reported and the server starts without reload.
        """
        import uvicorn

        console.print("\n" + "="*60)
        console.print("[bold magenta]🚀 DEPLOYMENT SERVER[/]")
        console.print("="*60)

        # Load model if not already loaded
        if self.model is None:
            self.load_model()

        # Create app if not already created
        if self.app is None:
            self.create_app()

        if reload:
            # uvicorn only reloads apps given as an import string; passing
            # reload=True together with an app object makes it exit.
            console.print(
                "[yellow]⚠️ Auto-reload is not supported for an in-process "
                "app; starting without reload.[/]"
            )

        console.print(f"\n[bold green]Starting server at http://{self.host}:{self.port}[/]")
        console.print("[dim]Press Ctrl+C to stop[/]\n")

        uvicorn.run(
            self.app,
            host=self.host,
            port=self.port
        )
def main():
    """
    CLI entry point for deployment.

    Parses the command-line options, builds a DeploymentServer from them and
    starts it. ``--model`` is required; ``--host``, ``--port``,
    ``--max-seq-length`` and ``--reload`` are optional.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Deploy fine-tuned model as API")
    parser.add_argument("--model", required=True, help="Path to fine-tuned model")
    parser.add_argument("--host", default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    # Exposes the server's max_seq_length; the default matches the
    # DeploymentServer default, so existing invocations behave as before.
    parser.add_argument(
        "--max-seq-length",
        type=int,
        default=2048,
        help="Maximum sequence length",
    )
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload")
    args = parser.parse_args()

    server = DeploymentServer(
        model_path=args.model,
        host=args.host,
        port=args.port,
        max_seq_length=args.max_seq_length
    )
    server.run(reload=args.reload)
# Start the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()