Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Secure HTTP API wrapper for TiMBL server. | |
| Run this on your server instead of raw timblserver. | |
| Features: | |
| - Bearer token authentication | |
| - HTTPS support (via reverse proxy or uvicorn) | |
| - Rate limiting | |
| - Request logging | |
| """ | |
import hmac
import logging
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional

import timbl
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Configuration (every value can be overridden through environment variables)
_INSECURE_DEFAULT_TOKEN = "change-me-to-a-secure-token"
API_TOKEN = os.environ.get("TIMBL_API_TOKEN", _INSECURE_DEFAULT_TOKEN)
MODEL_PATH = os.environ.get("TIMBL_MODEL_PATH", "model.igtree")
TIMBL_ARGS = os.environ.get("TIMBL_ARGS", "-a1 +D")
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100"))  # per window
RATE_LIMIT_WINDOW = 60  # seconds

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The fallback token is published in this source file, so anyone can use it;
# an empty token is no safer. Warn loudly instead of failing silently.
if not API_TOKEN or API_TOKEN == _INSECURE_DEFAULT_TOKEN:
    logger.warning(
        "TIMBL_API_TOKEN is unset, empty, or still the default placeholder; "
        "set it to a long random secret before exposing this server."
    )
app = FastAPI(
    title="OLIFANT TiMBL API",
    description="Secure API for TiMBL classification",
    version="1.0.0",
)

# CORS. Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated list).
# The default "*" keeps the original allow-any-origin behaviour; restrict it
# to your HF Space URL in production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Authentication is a bearer token in the Authorization header, not a
    # cookie, so credentialed CORS is not needed. A wildcard origin list plus
    # allow_credentials=True makes Starlette reflect the caller's Origin
    # together with Access-Control-Allow-Credentials: true, so credentials
    # are only allowed when explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["POST"],
    allow_headers=["*"],
)
# Rate limiting storage: bucket key -> timestamps of the requests accepted in
# the current window. check_rate_limit() prunes and appends to these lists;
# missing keys start out as an empty list.
request_counts: Dict[str, List[float]] = defaultdict(list)
class ClassifyRequest(BaseModel):
    # Request body of the classification endpoint: the feature values, in
    # order, which are passed unchanged to classifier.classify().
    features: List[str]
class ClassifyResponse(BaseModel):
    # Response body of the classification endpoint, built from the
    # (label, distribution, distance) triple that classifier.classify() returns.
    label: str  # predicted label
    distribution: Dict[str, float]  # presumably label -> weight; values coerced to float
    distance: float  # distance reported by TiMBL; 0.0 when TiMBL gives a falsy value
class HealthResponse(BaseModel):
    # Body returned by the health check endpoint.
    status: str  # "ok" when the model is loaded, otherwise "error"
    model_loaded: bool  # mirrors the module-level model_loaded flag
# Initialize TiMBL: module-level state assigned by load_model() and read by
# classify() and health_check(). classify() answers 503 until both are set.
classifier: Optional["timbl.TimblClassifier"] = None
model_loaded: bool = False
# Registered as a startup handler so the model is loaded once when the app
# boots; without a registration the function is never called. (on_event is
# superseded by lifespan handlers in recent FastAPI but remains supported.)
@app.on_event("startup")
async def load_model():
    """Load the TiMBL model when the application starts.

    Builds a ``timbl.TimblClassifier`` from ``MODEL_PATH`` and ``TIMBL_ARGS``
    and calls its ``load()``. The module-level ``classifier`` and
    ``model_loaded`` are updated only after both steps succeed.

    A failure is logged with its traceback and leaves ``model_loaded`` False.
    The server still starts: the health check then reports ``"error"`` and
    classification requests get a 503.
    """
    global classifier, model_loaded
    logger.info("Loading TiMBL model from %s...", MODEL_PATH)
    try:
        candidate = timbl.TimblClassifier(MODEL_PATH, TIMBL_ARGS)
        candidate.load()
    except Exception as e:
        # Deliberate best effort: keep serving so the failure is visible via
        # the health check. logger.exception keeps the traceback.
        logger.exception("Failed to load model: %s", e)
        model_loaded = False
        return
    # Publish only a fully loaded classifier, never a half-initialised one.
    classifier = candidate
    model_loaded = True
    logger.info("Model loaded successfully")
def verify_token(authorization: Optional[str] = Header(None)) -> str:
    """Validate the ``Authorization: Bearer <token>`` request header.

    Used as a FastAPI dependency on protected endpoints.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the client sent no such header.

    Returns:
        The bearer token, i.e. the header value with the scheme removed.

    Raises:
        HTTPException: 401, with a ``WWW-Authenticate: Bearer`` challenge,
            when the header is missing, does not use the Bearer scheme, or
            carries a token that does not equal ``API_TOKEN``.
    """
    challenge = {"WWW-Authenticate": "Bearer"}
    # Header(None) rather than Header(...): a missing header is an
    # authentication failure (401), not a request validation error (422).
    if authorization is None:
        raise HTTPException(
            status_code=401,
            detail="Missing authorization header",
            headers=challenge,
        )
    scheme, _, token = authorization.partition(" ")
    # Auth schemes are case-insensitive (RFC 7235), so "bearer" is accepted too.
    if scheme.lower() != "bearer":
        raise HTTPException(
            status_code=401,
            detail="Invalid authorization header",
            headers=challenge,
        )
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. An empty token is always rejected, even when
    # API_TOKEN is misconfigured as an empty string.
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), API_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
    return token
def check_rate_limit(client_ip: str = "default"):
    """Simple sliding-window rate limiting.

    Accepts at most ``RATE_LIMIT_REQUESTS`` requests per bucket within any
    ``RATE_LIMIT_WINDOW``-second span. Timestamps of accepted requests are
    kept per bucket in ``request_counts``.

    Args:
        client_ip: Bucket key. The default ``"default"`` puts every caller
            that passes no key into one shared bucket.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header (in seconds) when
            the bucket is full.
    """
    # Monotonic clock: unlike time.time(), it is not affected by wall-clock
    # changes (NTP corrections, manual adjustments), which would otherwise
    # stretch or shrink the window.
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW

    # Drop timestamps that have slid out of the window.
    recent = [t for t in request_counts[client_ip] if t > window_start]
    request_counts[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        # Seconds until the oldest counted request leaves the window, rounded
        # up. Fall back to a full window when nothing is counted, which only
        # happens if RATE_LIMIT_REQUESTS <= 0.
        retry_after = int(recent[0] - window_start) + 1 if recent else RATE_LIMIT_WINDOW
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(retry_after)},
        )
    # `recent` is the list stored in request_counts, so this records the request.
    recent.append(now)
# NOTE(review): no route decorator is visible in the source, so the endpoint
# was never registered; the "/health" path is reconstructed. Confirm it
# matches what clients call.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint (no auth required).

    Returns:
        HealthResponse whose ``status`` is ``"ok"`` when the model was loaded
        and ``"error"`` otherwise, plus the raw ``model_loaded`` flag.
    """
    return HealthResponse(
        status="ok" if model_loaded else "error",
        model_loaded=model_loaded,
    )
# NOTE(review): no route decorator is visible in the source, so the endpoint
# was never registered; the "/classify" path is reconstructed. Confirm it
# matches what clients call.
@app.post("/classify", response_model=ClassifyResponse)
async def classify(
    request: ClassifyRequest,
    token: str = Depends(verify_token),  # unused; declared to enforce auth
):
    """
    Classify a feature vector.

    Requires Bearer token authentication.

    Responses:
        200: the predicted label, the distribution (values as floats) and the
            distance (0.0 when TiMBL reports a falsy value).
        429: the shared rate limit is exhausted.
        503: the model is not loaded.
        500: classification raised; the detail carries the error message.
    """
    # verify_token is a dependency resolved before this body runs, so
    # unauthenticated requests never consume quota. No key is passed, so all
    # callers share one bucket.
    check_rate_limit()

    if not model_loaded or classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): classifier.classify() is a synchronous call inside an
    # async endpoint, so it blocks the event loop while it runs. Moving it to
    # a worker thread would first need confirmation that the TiMBL classifier
    # is thread-safe.
    try:
        label, distribution, distance = classifier.classify(request.features)
        # Coerce values to plain floats so the response model validates.
        dist_dict = {k: float(v) for k, v in distribution.items()} if distribution else {}
        return ClassifyResponse(
            label=label,
            distribution=dist_dict,
            distance=float(distance) if distance else 0.0,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Classification error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", "8000"))
    # Bind address. The default stays "all interfaces". When TLS is terminated
    # by a reverse proxy on the same machine, HOST=127.0.0.1 keeps the
    # plain-HTTP port off the network.
    host = os.environ.get("HOST", "0.0.0.0")
    print(f"""
========================================
OLIFANT TiMBL API Server
========================================
Starting server on {host}:{port}
Set these environment variables:
  TIMBL_API_TOKEN     - Secret token for authentication
  TIMBL_MODEL_PATH    - Path to .ibase file
  TIMBL_ARGS          - TiMBL arguments (default: -a1 +D)
  RATE_LIMIT_REQUESTS - Requests allowed per minute (default: 100)
  PORT                - Listening port (default: 8000)
  HOST                - Bind address (default: 0.0.0.0)
For HTTPS, run behind nginx/Caddy or use:
  uvicorn timbl_api_server:app --host 0.0.0.0 --port 8000 --ssl-keyfile key.pem --ssl-certfile cert.pem
========================================
""")
    uvicorn.run(app, host=host, port=port)