|
|
|
|
|
""" |
|
|
FastAPI server for serving Mistral 7B fine-tuned models |
|
|
""" |
|
|
|
|
|
import os
import sys
import threading
from typing import Optional, Dict, Any

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
|
|
import sys |
|
|
from pathlib import Path |
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
from inference.inference_mistral7b import load_local_model, generate_with_local_model, get_device_info |
|
|
import torch |
|
|
|
|
|
|
|
|
# Default checkpoint location: the "mistral7b-finetuned-ahb2apb" directory
# that sits next to the folder containing this file. The startup hook lets the
# MODEL_PATH environment variable override it.
_MODEL_BASE = Path(__file__).parent.parent.joinpath("mistral7b-finetuned-ahb2apb")
DEFAULT_MODEL_PATH = os.fspath(_MODEL_BASE)

# Process-wide inference state. Everything starts as None and is assigned by
# the startup hook; request handlers only read it.
model: Any = None
tokenizer: Any = None
# Presumably a dict from get_device_info() (handlers index it with "device").
device_info: Optional[Dict[str, Any]] = None
|
|
|
|
|
app = FastAPI(
    title="Mistral 7B AHB2APB API",
    description="API for serving the fine-tuned Mistral 7B model for AHB2APB conversion",
    version="1.0.0",
)

# Browser origins allowed to call this API: a comma-separated list taken from
# the CORS_ALLOW_ORIGINS environment variable. Defaults to "*" (any origin).
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    # Credentialed (cookie/auth) cross-origin requests are enabled only when the
    # origins are listed explicitly. With a wildcard origin plus credentials,
    # Starlette echoes back whatever Origin the browser sends, which would let
    # any website make credentialed requests on a visitor's behalf.
    allow_credentials="*" not in _CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for a single text generation."""

    # Input text handed to the model.
    prompt: str

    # Generation length limit forwarded to generate_with_local_model; whether it
    # counts prompt tokens is up to that helper (TODO confirm). None -> 512.
    max_length: Optional[int] = 512

    # Sampling temperature forwarded to generate_with_local_model. None -> 0.7.
    temperature: Optional[float] = 0.7


class GenerateResponse(BaseModel):
    """Result of /api/generate, echoing the parameters that were applied."""

    # Text returned by generate_with_local_model.
    response: str

    # Model path in use (MODEL_PATH env var, falling back to DEFAULT_MODEL_PATH).
    model: str

    # Effective generation parameters used for this request.
    max_length: int

    temperature: float


class HealthResponse(BaseModel):
    """Payload returned by the /health endpoint."""

    # "healthy" when the model global is set, otherwise "error".
    status: str

    # NOTE(review): under pydantic v2, field names starting with "model_" may
    # emit a protected-namespace warning at import time — confirm the version.
    model_loaded: bool

    # Device name from get_device_info(), or "unknown" before it has run.
    device: str

    # Model path in use (MODEL_PATH env var, falling back to DEFAULT_MODEL_PATH).
    model_path: str
|
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Load the model and tokenizer into the module globals at server startup.

    The checkpoint location comes from the MODEL_PATH environment variable and
    falls back to DEFAULT_MODEL_PATH. On any failure, the error and its full
    traceback are printed and ``SystemExit(1)`` is raised to abort startup, so
    the server does not come up without a model.
    """
    global model, tokenizer, device_info

    import traceback  # only needed on the failure path

    model_path = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)

    print(f"Loading model from: {model_path}")
    print("=" * 70)

    try:
        device_info = get_device_info()
        model, tokenizer = load_local_model(model_path)
        print(f"\n✓ Model loaded successfully on {device_info['device']}!")
        print("✓ Server ready to accept requests")
        print("=" * 70)
    except Exception as e:
        print(f"\n✗ Error loading model: {e}")
        # str(e) alone hides where loading failed (missing files, OOM deep inside
        # the loader, ...); keep the traceback so the failure is diagnosable.
        traceback.print_exc()
        print("=" * 70)
        sys.exit(1)
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report model readiness, the compute device, and the configured model path."""
    if model is not None:
        status, loaded = "healthy", True
    else:
        status, loaded = "error", False

    # device_info stays falsy until the startup hook has queried the device.
    if device_info:
        device_label = device_info["device"]
    else:
        device_label = "unknown"

    return HealthResponse(
        status=status,
        model_loaded=loaded,
        device=device_label,
        model_path=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
    )
|
|
|
|
|
@app.get("/")
async def root():
    """Return basic service metadata and a map of the main routes."""
    routes = {
        "health": "/health",
        "generate": "/api/generate",
        "docs": "/docs",
    }
    info = {
        "name": "Mistral 7B AHB2APB API",
        "version": "1.0.0",
        "status": "running",
        "model": os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
        "endpoints": routes,
    }
    return info
|
|
|
|
|
# Fallbacks used when a request leaves max_length / temperature unset (None).
# They mirror the defaults declared on GenerateRequest.
_DEFAULT_MAX_LENGTH = 512
_DEFAULT_TEMPERATURE = 0.7

# A single model instance is shared by every request. Generation runs in worker
# threads so the event loop stays free to answer other requests (e.g. /health).
# This lock keeps those threads from driving the model concurrently, matching
# the one-at-a-time behavior of the previous blocking implementation.
_generation_lock = threading.Lock()


def _resolve_generation_params(req: GenerateRequest) -> tuple[int, float]:
    """Return the (max_length, temperature) to use for *req*.

    Only ``None`` is replaced by a default, so explicit falsy values such as
    ``temperature=0.0`` are honored instead of silently becoming 0.7.
    """
    max_length = _DEFAULT_MAX_LENGTH if req.max_length is None else req.max_length
    temperature = _DEFAULT_TEMPERATURE if req.temperature is None else req.temperature
    return max_length, temperature


def _generate_locked(prompt: str, max_length: int, temperature: float) -> str:
    """Run one blocking generation while holding the model lock.

    Intended to be executed in a worker thread via ``run_in_threadpool``.
    """
    with _generation_lock:
        return generate_with_local_model(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
        )


@app.post("/api/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """
    Generate text from a prompt using the fine-tuned model.

    Responds 503 if the model is not loaded and 500 if generation fails. The
    response echoes the effective max_length and temperature.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    max_length, temperature = _resolve_generation_params(request)

    try:
        # Offload the blocking model call so the event loop is not frozen.
        response = await run_in_threadpool(
            _generate_locked, request.prompt, max_length, temperature
        )
        return GenerateResponse(
            response=response,
            model=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
            max_length=max_length,
            temperature=temperature,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e


@app.post("/api/generate/batch")
async def generate_batch(requests: list[GenerateRequest]):
    """
    Generate text from multiple prompts (batch processing).

    Prompts are processed sequentially, in request order. The result is
    ``{"results": [{"response": ..., "prompt": ...}, ...]}``. The first failing
    prompt aborts the whole batch with a 500.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        responses = []
        for req in requests:
            max_length, temperature = _resolve_generation_params(req)
            response = await run_in_threadpool(
                _generate_locked, req.prompt, max_length, temperature
            )
            responses.append({
                "response": response,
                "prompt": req.prompt,
            })
        return {"results": responses}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch generation error: {str(e)}") from e
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Start Mistral 7B API server")
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL_PATH,
        help=f"Path to fine-tuned model (default: {DEFAULT_MODEL_PATH})",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to bind to (default: 8000)",
    )
    parser.add_argument(
        "--reload",
        action="store_true",
        help="Enable auto-reload (for development)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes (default: 1)",
    )
    args = parser.parse_args()

    # Pass the model path through the environment. uvicorn loads the app from an
    # import string (below), so the module instance that serves requests only
    # sees it via MODEL_PATH, which the startup hook reads.
    os.environ["MODEL_PATH"] = args.model_path

    # Reload mode always runs a single worker; report and use the effective value.
    workers = 1 if args.reload else args.workers

    print("\n🚀 Starting Mistral 7B AHB2APB API Server")
    print(f" Model: {args.model_path}")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}")
    print(f" Workers: {workers}")
    print(f" Reload: {args.reload}\n")

    # Build the "<module>:app" import string from this file's name rather than
    # hard-coding "api_server", so renaming the file does not break startup.
    app_import_string = f"{Path(__file__).stem}:app"

    # NOTE(review): presumably switches to this file's directory so cwd-relative
    # behavior (module lookup, what --reload watches) targets this folder.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    uvicorn.run(
        app_import_string,
        host=args.host,
        port=args.port,
        reload=args.reload,
        workers=workers,
    )
|
|
|
|
|
|