olifant-generate-server / server /timbl_api_server.py
Antal van den Bosch
Remove packages.txt entirely - not needed for HTTP client
d0ff18b unverified
#!/usr/bin/env python3
"""
Secure HTTP API wrapper for TiMBL server.
Run this on your server instead of raw timblserver.
Features:
- Bearer token authentication
- HTTPS support (via reverse proxy or uvicorn)
- Rate limiting
- Request logging
"""
import hmac
import logging
import os
import time
import warnings
from collections import defaultdict
from typing import Dict, List, Optional

import timbl
from fastapi import Depends, FastAPI, Header, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Configuration (every value can be overridden through environment variables)
_DEFAULT_API_TOKEN = "change-me-to-a-secure-token"
API_TOKEN = os.environ.get("TIMBL_API_TOKEN", _DEFAULT_API_TOKEN)
if not API_TOKEN or API_TOKEN == _DEFAULT_API_TOKEN:
    # The placeholder is published in this source file, so running with it
    # (or with an empty token) means the API is effectively unauthenticated.
    # Warn loudly instead of refusing to start, so existing deployments keep
    # booting while the misconfiguration is visible in the logs.
    warnings.warn(
        "TIMBL_API_TOKEN is unset, empty, or the public placeholder; "
        "set it to a secret value before exposing this server.",
        RuntimeWarning,
        stacklevel=1,
    )
MODEL_PATH = os.environ.get("TIMBL_MODEL_PATH", "model.igtree")
TIMBL_ARGS = os.environ.get("TIMBL_ARGS", "-a1 +D")
# Sliding-window length for rate limiting, in seconds (default: one minute).
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))
# Maximum accepted requests per client within one RATE_LIMIT_WINDOW.
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100"))
# Process-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application. Title/description/version only feed the generated
# OpenAPI schema and docs pages.
app = FastAPI(
    version="1.0.0",
    description="Secure API for TiMBL classification",
    title="OLIFANT TiMBL API",
)
# CORS. Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated); the
# default "*" keeps the permissive behaviour. Restrict it to your HF Space URL
# in production.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # Authentication uses a bearer token in the Authorization header, not
    # cookies, so credentialed CORS is unnecessary. Enabling it together with
    # a "*" origin makes Starlette echo back any requesting origin with
    # credentials allowed.
    allow_credentials=False,
    allow_methods=["POST"],
    allow_headers=["*"],
)

# Rate-limiting state: client key -> timestamps of accepted requests that are
# still inside the current window (maintained by check_rate_limit).
request_counts: Dict[str, List[float]] = defaultdict(list)
class ClassifyRequest(BaseModel):
    """Body of POST /classify."""

    # Feature values, forwarded unchanged to classifier.classify().
    features: List[str]


class ClassifyResponse(BaseModel):
    """Result of classifying one feature vector."""

    # Predicted class label.
    label: str
    # Per-label scores from the classifier (label -> float).
    distribution: Dict[str, float]
    # Distance reported by the classifier (0.0 when it returned none).
    distance: float


class HealthResponse(BaseModel):
    """Body of GET /health."""

    # "ok" when the model is loaded, otherwise "error".
    status: str
    # Whether the TiMBL model was loaded successfully at startup.
    model_loaded: bool
# Initialize TiMBL
classifier = None
model_loaded = False
@app.on_event("startup")
async def load_model():
    """Load the TiMBL model once, when the application starts.

    On success, sets the module globals ``classifier`` and ``model_loaded``.
    On failure, logs the error with its traceback, resets ``classifier`` to
    None and leaves ``model_loaded`` False. The process keeps running, so
    /health reports "error" and /classify answers 503.
    """
    global classifier, model_loaded
    # NOTE(review): python-timbl's TimblClassifier treats its first argument as
    # a file *prefix* and appends ".ibase" itself when loading. Strip a
    # trailing ".ibase" so that a full path to the instance-base file also
    # works. Verify against the installed python-timbl version.
    suffix = ".ibase"
    model_prefix = MODEL_PATH[: -len(suffix)] if MODEL_PATH.endswith(suffix) else MODEL_PATH
    logger.info("Loading TiMBL model from %s...", MODEL_PATH)
    try:
        loaded = timbl.TimblClassifier(model_prefix, TIMBL_ARGS)
        loaded.load()
    except Exception:
        # Best-effort startup: record the full traceback and never publish a
        # half-initialised classifier object.
        logger.exception("Failed to load model")
        classifier = None
        model_loaded = False
        return
    classifier = loaded
    model_loaded = True
    logger.info("Model loaded successfully")
def verify_token(authorization: Optional[str] = Header(None)) -> str:
    """Validate an ``Authorization: Bearer <token>`` header against API_TOKEN.

    Used as a FastAPI dependency on protected endpoints.

    Args:
        authorization: Raw ``Authorization`` header value, or None when the
            client did not send the header.

    Returns:
        The presented token, after it has matched ``API_TOKEN``.

    Raises:
        HTTPException: 401 when the header is missing, does not use the
            Bearer scheme, or carries an empty or non-matching token.
    """
    # RFC 6750: bearer-auth 401 responses should include a WWW-Authenticate
    # challenge.
    challenge = {"WWW-Authenticate": "Bearer"}
    if authorization is None:
        # A missing header is an authentication failure (401), not a request
        # validation error (a required Header(...) would yield 422).
        raise HTTPException(
            status_code=401, detail="Missing authorization header", headers=challenge
        )
    scheme, sep, token = authorization.partition(" ")
    # Authentication schemes are case-insensitive (RFC 7235).
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401, detail="Invalid authorization header", headers=challenge
        )
    # Use a constant-time comparison so response timing does not reveal how
    # much of the token matched. Always reject an empty token, so an empty
    # TIMBL_API_TOKEN fails closed instead of accepting "Bearer ".
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), API_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
    return token
def check_rate_limit(client_ip: str = "default"):
    """Enforce a sliding-window request limit for one client key.

    For each ``client_ip``, keeps the timestamps of requests accepted during
    the last ``RATE_LIMIT_WINDOW`` seconds. Raises HTTP 429 when that key
    already has ``RATE_LIMIT_REQUESTS`` requests in the window. Otherwise it
    records the current request.

    Args:
        client_ip: Key identifying the caller. Callers that omit it share the
            single "default" bucket.

    Raises:
        HTTPException: 429 when the limit for ``client_ip`` is exceeded.
    """
    # Use a monotonic clock: if the wall clock jumped backwards (NTP, manual
    # change), time.time() would keep old timestamps inside the window and
    # lock the client out until real time caught up again.
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW

    # Forget timestamps that have slid out of the window.
    recent = [t for t in request_counts[client_ip] if t > window_start]
    request_counts[client_ip] = recent
    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    recent.append(now)

    # Evict buckets whose newest timestamp has expired, so the dict does not
    # keep one entry per client ever seen. Timestamps are appended in
    # increasing order, so ts[-1] is the newest. The current client has just
    # appended `now` and is never evicted.
    if len(request_counts) > 1024:
        stale = [key for key, ts in request_counts.items() if not ts or ts[-1] <= window_start]
        for key in stale:
            del request_counts[key]
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the TiMBL model is loaded. Requires no authentication.

    Returns:
        HealthResponse with status "ok" and model_loaded=True after a
        successful startup load; otherwise status "error" and
        model_loaded=False.
    """
    loaded = model_loaded
    status = "error"
    if loaded:
        status = "ok"
    return HealthResponse(model_loaded=loaded, status=status)
@app.post("/classify", response_model=ClassifyResponse)
async def classify(
    request: ClassifyRequest,
    http_request: Request,
    token: str = Depends(verify_token),
):
    """Classify one feature vector with the loaded TiMBL model.

    Requires Bearer token authentication (enforced by ``verify_token``) and
    is rate-limited per client address.

    Args:
        request: Body holding the feature values.
        http_request: The raw HTTP request; used only to find the client address.
        token: The validated bearer token (injected by FastAPI).

    Returns:
        ClassifyResponse with the predicted label, the label distribution
        (values coerced to float) and the distance (0.0 when none is given).

    Raises:
        HTTPException: 429 if the client exceeds the rate limit, 503 if the
            model is not loaded, 500 if classification fails.
    """
    # Rate-limit per client rather than through one global bucket shared by
    # every caller. Fall back to the shared key when the ASGI server supplies
    # no client address.
    client_ip = http_request.client.host if http_request.client else "default"
    check_rate_limit(client_ip)
    if not model_loaded or classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # NOTE(review): this synchronous call runs on the event loop, which
        # serialises access to the classifier. Presumably that is desirable if
        # the TiMBL wrapper is not thread-safe; confirm before moving it to a
        # thread pool.
        label, distribution, distance = classifier.classify(request.features)
        # Coerce distribution values to plain floats for the response model.
        dist_dict = {k: float(v) for k, v in distribution.items()} if distribution else {}
        return ClassifyResponse(
            label=label,
            distribution=dist_dict,
            distance=float(distance) if distance else 0.0,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Classification error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Local import: uvicorn is only needed when run as a script, not when an
    # external ASGI server imports `app`.
    import uvicorn

    port = int(os.environ.get("PORT", "8000"))
    # The banner lists every environment variable this script reads directly.
    # The HTTPS example reuses the configured port instead of a hard-coded one.
    print(f"""
========================================
OLIFANT TiMBL API Server
========================================
Starting server on port {port}
Set these environment variables:
TIMBL_API_TOKEN - Secret token for authentication
TIMBL_MODEL_PATH - Path to .ibase file
TIMBL_ARGS - TiMBL arguments (default: -a1 +D)
RATE_LIMIT_REQUESTS - Max /classify requests per rate-limit window (default: 100)
PORT - Port to listen on (default: 8000)
For HTTPS, run behind nginx/Caddy or use:
uvicorn timbl_api_server:app --host 0.0.0.0 --port {port} --ssl-keyfile key.pem --ssl-certfile cert.pem
========================================
""")
    uvicorn.run(app, host="0.0.0.0", port=port)