|
|
|
|
|
""" |
|
|
Secure FastAPI server for serving Mistral 7B fine-tuned models |
|
|
Includes API key authentication like commercial services |
|
|
""" |
|
|
|
|
|
import asyncio
import hmac
import os
import secrets
import sys
import threading
import traceback
from pathlib import Path
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Header, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field

# Put the directory two levels above this file on sys.path so the local
# `inference` package below can be imported.
sys.path.insert(0, str(Path(__file__).parent.parent))

from inference.inference_mistral7b import load_local_model, generate_with_local_model, get_device_info
|
|
|
|
|
|
|
|
# Fine-tuned model directory, located two directory levels above this file.
# startup() uses it only when the MODEL_PATH environment variable is unset.
_MODEL_BASE = Path(__file__).parent.parent.joinpath("mistral7b-finetuned-ahb2apb")
DEFAULT_MODEL_PATH = os.fspath(_MODEL_BASE)
|
|
|
|
|
|
|
|
# Accepted API keys; replaced wholesale by load_api_keys().
API_KEYS = set()

# Key file with one key per line. A relative path resolves against the
# process's current working directory.
API_KEY_FILE = "api_keys.txt"


def _read_api_key_file(path):
    """Return the set of API keys listed in *path*.

    Each line is whitespace-stripped; blank lines and lines starting with
    ``#`` are skipped, so comments in the file can never act as valid keys.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return {key for key in stripped if key and not key.startswith("#")}


def _create_api_key_file(path):
    """Create *path* containing one freshly generated key and return that key.

    The file is created exclusively (``O_EXCL``) with permission bits 0o600
    (owner read/write only on POSIX, further limited by the umask) so the
    secret is not readable by other local users.

    Raises:
        FileExistsError: if *path* already exists.
    """
    key = secrets.token_urlsafe(32)
    # O_BINARY (Windows only) avoids a second CRLF translation below the text layer.
    flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0)
    fd = os.open(path, flags, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(key + "\n")
    return key


def load_api_keys():
    """Populate the global ``API_KEYS`` from ``API_KEY_FILE``.

    If the file does not exist, a random key is generated, written to a newly
    created owner-only file, and printed once so the operator can record it.
    If the file exists but yields no keys, ``API_KEYS`` ends up empty (so every
    key check fails) and a warning is printed.
    """
    global API_KEYS

    generated_key = None
    if not os.path.exists(API_KEY_FILE):
        try:
            generated_key = _create_api_key_file(API_KEY_FILE)
        except FileExistsError:
            # Another process created the file after the existence check;
            # fall through and load whatever it wrote.
            pass

    if generated_key is not None:
        API_KEYS = {generated_key}
        print(f"\nGenerated default API key: {generated_key}")
        print(f" Save this key! Store it in: {API_KEY_FILE}")
    else:
        API_KEYS = _read_api_key_file(API_KEY_FILE)
        if not API_KEYS:
            print(f"\n[WARN] No API keys found in {API_KEY_FILE}; all authenticated requests will be rejected")

    print(f"[OK] Loaded {len(API_KEYS)} API key(s)")
|
|
|
|
|
def verify_api_key(
    api_key: Optional[str] = Header(None, alias="X-API-Key"),
    legacy_api_key: Optional[str] = Header(None, alias="api-key"),
):
    """FastAPI dependency that authenticates a request by API key.

    The key is read from the ``X-API-Key`` header, matching the instructions
    in the 401 error message. The explicit alias is required: without it,
    FastAPI derives the header name ``api-key`` from the parameter name. That
    ``api-key`` header is still accepted as a fallback so existing clients
    keep working.

    Returns:
        The accepted API key.

    Raises:
        HTTPException: 401 if no key header is present; 403 if the key is not
            one of ``API_KEYS``.
    """
    supplied = api_key if api_key is not None else legacy_api_key
    if supplied is None:
        raise HTTPException(
            status_code=401,
            detail="API key required. Add header: 'X-API-Key: your-api-key'",
        )

    # hmac.compare_digest instead of set membership / == to resist timing attacks.
    supplied_bytes = supplied.encode("utf-8")
    if not any(hmac.compare_digest(supplied_bytes, key.encode("utf-8")) for key in API_KEYS):
        raise HTTPException(status_code=403, detail="Invalid API key")

    return supplied
|
|
|
|
|
|
|
|
# Module-level inference state shared by the endpoints. All three stay None
# until startup() assigns them.
model = tokenizer = device_info = None
|
|
|
|
|
# The ASGI application. This metadata is shown in the interactive docs
# (advertised at /docs by the root endpoint).
app = FastAPI(
    version="1.0.0",
    title="Mistral 7B AHB2APB API (Secure)",
    description="Secure API for serving the fine-tuned Mistral 7B model for AHB2APB conversion",
)
|
|
|
|
|
|
|
|
# Browser origins allowed to call the API: a comma-separated list in the
# CORS_ALLOW_ORIGINS environment variable; defaults to any origin ("*").
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin combined with credentials is invalid under the CORS
    # spec. Starlette handles that combination by reflecting the caller's
    # Origin, which would grant credentialed access to every website. The API
    # authenticates with a header, not cookies, so credentials are only
    # enabled when origins are listed explicitly.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# HTTP Bearer (Authorization header) security scheme, created with
# auto_error=False so the scheme object itself does not reject requests.
# NOTE(review): not referenced by any route or dependency in this source;
# authentication is performed by verify_api_key. Wire it in or remove it.
security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Request body for POST /api/generate.

    ``max_length`` and ``temperature`` may be omitted (defaults 512 and 0.7)
    or sent as null. Non-null values are validated: before this, zero or
    negative values were accepted and then either silently replaced or passed
    straight to the model.
    """

    # Prompt text passed to the model.
    prompt: str
    # Forwarded as max_length to generate_with_local_model (its units are
    # defined there); must be >= 1 when given.
    max_length: Optional[int] = Field(512, ge=1)
    # Sampling temperature forwarded to the model; must be > 0 when given.
    temperature: Optional[float] = Field(0.7, gt=0.0)
|
|
|
|
|
class GenerateResponse(BaseModel):
    """Response body for POST /api/generate."""

    # Text returned by generate_with_local_model.
    response: str

    # Model location used by the server (the MODEL_PATH value or its default).
    model: str

    # Effective max_length used for this generation.
    max_length: int

    # Effective sampling temperature used for this generation.
    temperature: float
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for GET /health."""

    # "healthy" when the model is loaded, otherwise "error".
    status: str

    # Whether a model object is currently loaded.
    model_loaded: bool

    # Device reported by get_device_info(), or "unknown" before it has run.
    device: str

    # Model location (the MODEL_PATH value or its default).
    model_path: str

    # Authentication status string reported by the health endpoint.
    authentication: str
|
|
|
|
|
@app.on_event("startup")
async def startup():
    """Load the API keys and the model once, when the server starts.

    The model location comes from the MODEL_PATH environment variable
    (default: DEFAULT_MODEL_PATH). Any loading error is printed together with
    its traceback, and the process exits with status 1, so the server never
    keeps running without a model.
    """
    global model, tokenizer, device_info

    load_api_keys()

    model_path = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)
    print(f"\nLoading model from: {model_path}")
    print("=" * 70)

    try:
        device_info = get_device_info()
        model, tokenizer = load_local_model(model_path)
        print(f"\n[OK] Model loaded successfully on {device_info['device']}!")
        print("[OK] API server ready (authentication enabled)")
        print("=" * 70)
    except Exception as e:
        print(f"\n[ERROR] Error loading model: {e}")
        # The one-line message alone hides where loading failed; keep the traceback.
        traceback.print_exc()
        sys.exit(1)
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded and on which device.

    This route has no API-key dependency, so it can be polled without credentials.
    """
    loaded = model is not None

    device = "unknown"
    if device_info:
        device = device_info["device"]

    return HealthResponse(
        authentication="enabled",
        model_path=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
        device=device,
        model_loaded=loaded,
        status="healthy" if loaded else "error",
    )
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list its endpoints (no API key required)."""
    endpoint_map = dict(
        health="/health",
        generate="/api/generate (requires API key)",
        docs="/docs",
    )
    return dict(
        name="Mistral 7B AHB2APB API (Secure)",
        version="1.0.0",
        status="running",
        authentication="API key required",
        model=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
        endpoints=endpoint_map,
    )
|
|
|
|
|
# Serializes model inference. Generation now runs in a worker thread so it no
# longer blocks the event loop; this lock keeps calls one-at-a-time, as they
# were when generation ran directly on the loop.
_generation_lock = threading.Lock()


def _run_generation(prompt, max_length, temperature):
    """Call generate_with_local_model under ``_generation_lock`` (executor thread)."""
    with _generation_lock:
        return generate_with_local_model(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
        )


@app.post("/api/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest, api_key: str = Depends(verify_api_key)):
    """Generate text from a prompt (requires an API key via verify_api_key).

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if generation fails.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Substitute defaults only for omitted/null values. `value or default`
    # would also silently replace an explicit 0 / 0.0.
    max_length = request.max_length if request.max_length is not None else 512
    temperature = request.temperature if request.temperature is not None else 0.7

    try:
        # The model call is blocking; run it off the event loop so /health and
        # other requests stay responsive during generation.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, _run_generation, request.prompt, max_length, temperature
        )

        return GenerateResponse(
            response=response,
            model=os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH),
            max_length=max_length,
            temperature=temperature,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation error: {str(e)}") from e
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Start Secure Mistral 7B API server")
    parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--reload", action="store_true")
    args = parser.parse_args()

    # Passed through the environment because uvicorn imports the app from an
    # import string; startup() reads MODEL_PATH from there.
    os.environ["MODEL_PATH"] = args.model_path

    print("\nStarting Secure Mistral 7B AHB2APB API Server")
    print(f" Model: {args.model_path}")
    print(f" Host: {args.host}")
    print(f" Port: {args.port}\n")

    # Run from this file's directory so relative paths (e.g. API_KEY_FILE)
    # resolve next to this file.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Derive the import string from this file's name rather than hard-coding
    # it, so renaming the file does not break startup.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host=args.host, port=args.port, reload=args.reload)
|
|
|
|
|
|