Spaces:
Sleeping
Sleeping
File size: 4,847 Bytes
d0ff18b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
#!/usr/bin/env python3
"""
Secure HTTP API wrapper for TiMBL server.
Run this on your server instead of raw timblserver.
Features:
- Bearer token authentication
- HTTPS support (via reverse proxy or uvicorn)
- Rate limiting
- Request logging
"""
import hmac
import logging
import os
import time
from collections import defaultdict
from typing import List, Dict, Optional

import timbl
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Configuration (every value except RATE_LIMIT_WINDOW is read from the environment)
# Placeholder used when TIMBL_API_TOKEN is not set; it is publicly known and
# must never protect a real deployment.
_DEFAULT_API_TOKEN = "change-me-to-a-secure-token"
API_TOKEN = os.environ.get("TIMBL_API_TOKEN", _DEFAULT_API_TOKEN)
MODEL_PATH = os.environ.get("TIMBL_MODEL_PATH", "model.igtree")
TIMBL_ARGS = os.environ.get("TIMBL_ARGS", "-a1 +D")
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100"))  # per window
RATE_LIMIT_WINDOW = 60  # seconds

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make an insecure configuration visible at startup instead of silently
# accepting the placeholder (or an empty) token.
if not API_TOKEN or API_TOKEN == _DEFAULT_API_TOKEN:
    logger.warning(
        "TIMBL_API_TOKEN is empty or still the built-in placeholder; "
        "set a strong secret before exposing this server."
    )
# ASGI application object; the keyword arguments are the API metadata handed
# to the FastAPI constructor.
app = FastAPI(
    title="OLIFANT TiMBL API",
    description="Secure API for TiMBL classification",
    version="1.0.0",
)
# CORS
# Allowed origins are read from the optional CORS_ALLOW_ORIGINS environment
# variable (comma-separated list). When it is unset the previous default of
# "*" is kept; restrict it to your HF Space URL in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin, so it is
    # only enabled once explicit origins are configured. The API authenticates
    # through the Authorization header, which is sent regardless of this flag.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["POST"],
    allow_headers=["*"],
)
# Rate limiting storage
# Maps a client key to the time.time() stamps of that key's requests; stamps
# older than the rate-limit window are dropped by check_rate_limit(). Held in
# this process's memory only.
request_counts: Dict[str, List[float]] = defaultdict(list)
# Request body accepted by POST /classify.
class ClassifyRequest(BaseModel):
    # Feature values of the instance; handed unchanged to classifier.classify().
    features: List[str]
# Response body returned by POST /classify.
class ClassifyResponse(BaseModel):
    # Label returned by the classifier for the submitted features.
    label: str
    # Distribution returned by the classifier, with every value coerced to float.
    distribution: Dict[str, float]
    # Distance returned by the classifier (0.0 when the classifier reports none).
    distance: float
# Response body returned by GET /health.
class HealthResponse(BaseModel):
    # "ok" when the model is loaded, otherwise "error".
    status: str
    # Mirrors the module-level model_loaded flag.
    model_loaded: bool
# Initialize TiMBL
# Module-level state assigned by the load_model() startup hook. model_loaded
# becomes True only after the model was loaded without an exception; the
# /health and /classify handlers read both names.
classifier = None
model_loaded = False
@app.on_event("startup")
async def load_model():
    """Load the TiMBL model once when the application starts.

    Builds a ``timbl.TimblClassifier`` from ``MODEL_PATH`` and ``TIMBL_ARGS``,
    calls its ``load()`` method and publishes the outcome through the
    module-level ``classifier`` and ``model_loaded`` globals.

    A failure is logged with its traceback instead of being re-raised, so the
    server still starts; ``/health`` then reports ``"error"`` and
    ``/classify`` answers 503.
    """
    global classifier, model_loaded
    logger.info(f"Loading TiMBL model from {MODEL_PATH}...")
    try:
        candidate = timbl.TimblClassifier(MODEL_PATH, TIMBL_ARGS)
        candidate.load()
    except Exception as e:
        # Startup boundary: keep the process alive, but record the full
        # traceback and never leave a half-initialised classifier behind.
        logger.exception(f"Failed to load model: {e}")
        classifier = None
        model_loaded = False
    else:
        # Publish the classifier only after load() succeeded.
        classifier = candidate
        model_loaded = True
        logger.info("Model loaded successfully")
def verify_token(authorization: str = Header(...)) -> str:
    """Validate the ``Authorization`` header and return the bearer token.

    Args:
        authorization: Raw ``Authorization`` header value, expected in the
            form ``Bearer <token>`` (the scheme name is matched
            case-insensitively).

    Returns:
        The token part of the header when it equals ``API_TOKEN``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            header is not a bearer credential, or when the token is empty or
            does not match ``API_TOKEN``.
    """
    challenge = {"WWW-Authenticate": "Bearer"}

    scheme, separator, token = authorization.partition(" ")
    if not separator or scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail="Invalid authorization header",
            headers=challenge,
        )

    # An empty token is always rejected, so an empty configured API_TOKEN can
    # never turn authentication off. The comparison is done on bytes with
    # hmac.compare_digest so its duration does not reveal how many leading
    # characters matched (and non-ASCII input cannot raise TypeError).
    supplied = token.encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if not token or not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid token",
            headers=challenge,
        )
    return token
def check_rate_limit(client_ip: str = "default"):
    """Record one request for ``client_ip`` or reject it with HTTP 429.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are discarded first.
    If ``RATE_LIMIT_REQUESTS`` or more remain, the request is refused;
    otherwise the current time is appended to the key's history.

    Raises:
        HTTPException: 429 when the key has used up its quota for the window.
    """
    timestamp = time.time()
    cutoff = timestamp - RATE_LIMIT_WINDOW

    # Keep only the requests that still fall inside the window.
    recent = [seen for seen in request_counts[client_ip] if seen > cutoff]
    request_counts[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    recent.append(timestamp)
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the model is loaded (no authentication required)."""
    if model_loaded:
        state = "ok"
    else:
        state = "error"
    return HealthResponse(status=state, model_loaded=model_loaded)
@app.post("/classify", response_model=ClassifyResponse)
async def classify(
    request: ClassifyRequest,
    token: str = Depends(verify_token)
):
    """
    Classify a feature vector.
    Requires Bearer token authentication (enforced by the ``verify_token``
    dependency before this body runs).

    Args:
        request: Parsed request body; ``request.features`` is passed to the
            classifier.
        token: The validated bearer token (not used further here).

    Returns:
        A ``ClassifyResponse`` with the label, distribution and distance
        returned by the classifier.

    Raises:
        HTTPException: 429 from ``check_rate_limit`` when the quota is used
            up, 503 when no model is loaded, 500 when classification fails.
    """
    # NOTE(review): no client identifier is passed, so every caller shares the
    # single "default" rate-limit bucket.
    check_rate_limit()

    if not model_loaded or classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        label, distribution, distance = classifier.classify(request.features)
        # Convert distribution to a plain dict with float values; a falsy
        # distribution (None or empty) becomes {}.
        dist_dict = {k: float(v) for k, v in distribution.items()} if distribution else {}
        return ClassifyResponse(
            label=label,
            distribution=dist_dict,
            # A falsy distance (None or 0) is reported as 0.0.
            distance=float(distance) if distance else 0.0
        )
    except Exception as e:
        # Log with traceback and keep the exception chain so the root cause
        # stays visible in the server logs.
        # NOTE(review): the exception text is also echoed to the client in
        # `detail`; confirm that is acceptable for this deployment.
        logger.exception(f"Classification error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Script entry point: print a short usage banner, then serve the app with
    # uvicorn on all interfaces at the port taken from $PORT (default 8000).
    import uvicorn

    port = int(os.getenv("PORT", "8000"))
    banner = f"""
========================================
OLIFANT TiMBL API Server
========================================
Starting server on port {port}
Set these environment variables:
TIMBL_API_TOKEN - Secret token for authentication
TIMBL_MODEL_PATH - Path to .ibase file
TIMBL_ARGS - TiMBL arguments (default: -a1 +D)
For HTTPS, run behind nginx/Caddy or use:
uvicorn timbl_api_server:app --host 0.0.0.0 --port 8000 --ssl-keyfile key.pem --ssl-certfile cert.pem
========================================
"""
    print(banner)
    uvicorn.run(app, host="0.0.0.0", port=port)
|