File size: 4,847 Bytes
d0ff18b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""
Secure HTTP API wrapper for TiMBL server.
Run this on your server instead of raw timblserver.

Features:
- Bearer token authentication
- HTTPS support (via reverse proxy or uvicorn)
- Rate limiting
- Request logging
"""

import logging
import os
import secrets
import time
from collections import defaultdict
from typing import List, Dict, Optional

import timbl
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Configuration (all values come from environment variables).

# Placeholder shipped as the token default. Running with it means anyone who has
# read this file can authenticate, so it is warned about at startup below.
_INSECURE_DEFAULT_TOKEN = "change-me-to-a-secure-token"

API_TOKEN = os.environ.get("TIMBL_API_TOKEN", _INSECURE_DEFAULT_TOKEN)
MODEL_PATH = os.environ.get("TIMBL_MODEL_PATH", "model.igtree")
TIMBL_ARGS = os.environ.get("TIMBL_ARGS", "-a1 +D")
# Maximum requests accepted per RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_REQUESTS = int(os.environ.get("RATE_LIMIT_REQUESTS", "100"))
# Length of the sliding rate-limit window, in seconds (default: one minute).
RATE_LIMIT_WINDOW = int(os.environ.get("RATE_LIMIT_WINDOW", "60"))

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fail loudly (but do not refuse to start) when authentication is effectively
# public: either the well-known placeholder or an empty secret is configured.
if not API_TOKEN or API_TOKEN == _INSECURE_DEFAULT_TOKEN:
    logger.warning(
        "TIMBL_API_TOKEN is unset, empty, or still the placeholder default; "
        "set it to a strong secret before exposing this server."
    )

app = FastAPI(
    title="OLIFANT TiMBL API",
    description="Secure API for TiMBL classification",
    version="1.0.0"
)

# CORS. Allowed origins are read from CORS_ALLOW_ORIGINS as a comma-separated
# list, so production can be restricted (e.g. to your HF Space URL) without
# editing code. The default "*" keeps the previous allow-everything behaviour;
# an explicitly empty value allows no cross-origin callers.
CORS_ALLOW_ORIGINS: List[str] = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

# NOTE(review): with a "*" origin plus allow_credentials=True, Starlette (as far
# as I know) reflects the caller's Origin rather than sending "*", because CORS
# forbids the wildcard with credentials. Auth here is a Bearer header, which does
# not need allow_credentials — consider turning it off; verify clients first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["POST"],
    allow_headers=["*"],
)

# Rate limiting storage: bucket key -> timestamps of accepted requests that
# still fall inside the current window (maintained by check_rate_limit).
request_counts: Dict[str, List[float]] = defaultdict(list)


class ClassifyRequest(BaseModel):
    """Request body for POST /classify."""

    # Feature values handed, as a list, to classifier.classify(). Presumably one
    # entry per feature column the model was trained on, in order — TODO confirm.
    features: List[str]


class ClassifyResponse(BaseModel):
    """Response body for POST /classify, built from classifier.classify()'s result."""

    # Class label returned by the classifier.
    label: str
    # Class label -> score, with values converted to float; empty when the
    # classifier returned no distribution.
    distribution: Dict[str, float]
    # Distance reported by the classifier, or 0.0 when it reported none.
    distance: float


class HealthResponse(BaseModel):
    """Response body for GET /health."""

    # "ok" when the model is loaded, otherwise "error".
    status: str
    # Whether the startup model load succeeded.
    model_loaded: bool


# Initialize TiMBL. Both globals are set by load_model() at startup; classify()
# refuses requests unless the model is loaded and the classifier is set.
classifier = None  # timbl.TimblClassifier once loaded successfully
model_loaded = False


@app.on_event("startup")
async def load_model():
    """Load the TiMBL instance base once when the application starts.

    Reads MODEL_PATH and TIMBL_ARGS. On success, publishes the classifier
    and sets ``model_loaded``. On failure, logs the error with its traceback
    and leaves both unset. It does not raise, so the server still starts and
    /health reports ``model_loaded: false``.
    """
    global classifier, model_loaded
    logger.info("Loading TiMBL model from %s...", MODEL_PATH)
    try:
        # Build into a local first so the global never points at a classifier
        # whose load() failed part-way.
        loaded = timbl.TimblClassifier(MODEL_PATH, TIMBL_ARGS)
        loaded.load()
    except Exception:
        # Deliberately best-effort: keep serving so /health can report it.
        logger.exception("Failed to load model from %s", MODEL_PATH)
        classifier = None
        model_loaded = False
        return

    classifier = loaded
    model_loaded = True
    logger.info("Model loaded successfully")


def verify_token(authorization: Optional[str] = Header(None)) -> str:
    """Validate the ``Authorization: Bearer <token>`` header against API_TOKEN.

    Intended as a FastAPI dependency on protected endpoints.

    Args:
        authorization: Raw value of the Authorization request header, or None
            when the header is absent.

    Returns:
        The presented token, once it has matched API_TOKEN.

    Raises:
        HTTPException: 401 when the header is missing, does not use the Bearer
            scheme, carries an empty token, or the token does not match.
    """
    # RFC 6750: a 401 for bearer auth should advertise the expected scheme.
    challenge = {"WWW-Authenticate": "Bearer"}

    if authorization is None:
        raise HTTPException(
            status_code=401, detail="Missing authorization header", headers=challenge
        )

    scheme, _, token = authorization.partition(" ")
    token = token.strip()
    # Auth schemes are case-insensitive (RFC 7235). An empty token is never
    # valid, even if TIMBL_API_TOKEN is (mis)configured as an empty string.
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=401, detail="Invalid authorization header", headers=challenge
        )

    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(token.encode("utf-8"), API_TOKEN.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

    return token


def check_rate_limit(client_ip: str = "default"):
    """Enforce a sliding-window limit of RATE_LIMIT_REQUESTS per RATE_LIMIT_WINDOW seconds.

    Args:
        client_ip: Bucket key to count against. Callers that omit it all share
            the single "default" bucket.

    Raises:
        HTTPException: 429 when the bucket already holds RATE_LIMIT_REQUESTS
            timestamps inside the current window. Rejected calls are not
            recorded, so they do not extend the lockout.
    """
    # time.monotonic() rather than time.time(): if the wall clock steps
    # backwards (NTP correction, manual change), stored timestamps would sit "in
    # the future" and keep the bucket full until the clock caught up again.
    now = time.monotonic()
    window_start = now - RATE_LIMIT_WINDOW

    # Keep only timestamps that are still inside the window.
    recent = [t for t in request_counts[client_ip] if t > window_start]
    request_counts[client_ip] = recent

    if len(recent) >= RATE_LIMIT_REQUESTS:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")

    recent.append(now)
    # NOTE(review): keys are never evicted. That is fine for the single
    # "default" bucket, but with per-client keys, entries for clients that
    # stop sending would accumulate indefinitely.


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report whether the TiMBL model is available (no auth required)."""
    if not model_loaded:
        return HealthResponse(status="error", model_loaded=model_loaded)
    return HealthResponse(status="ok", model_loaded=model_loaded)


@app.post("/classify", response_model=ClassifyResponse)
async def classify(
    request: ClassifyRequest,
    token: str = Depends(verify_token)
):
    """
    Classify a feature vector.

    Requires Bearer token authentication.

    Responses:
        200: the predicted label, per-class distribution and distance.
        429: the shared rate limit is exhausted.
        503: the model failed to load at startup.
        500: the classifier raised; details are logged server-side only.
    Authentication failures are reported by verify_token.
    """
    # NOTE(review): called without a key, so every client shares one global
    # rate-limit bucket. A per-client key would need the starlette Request (and,
    # behind a reverse proxy, a trusted forwarded-for header).
    check_rate_limit()

    if not model_loaded or classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # NOTE: synchronous call inside an async endpoint. It runs on the event
        # loop thread, so other requests wait while a classification is running.
        label, distribution, distance = classifier.classify(request.features)

        # Coerce scores to float for the response model; fall back to {} when
        # the classifier returned no (or an empty) distribution.
        dist_dict = {k: float(v) for k, v in distribution.items()} if distribution else {}

        return ClassifyResponse(
            label=label,
            distribution=dist_dict,
            distance=float(distance) if distance else 0.0
        )
    except Exception as e:
        # Keep the full traceback in the server log, but do not echo str(e) to
        # clients: it can expose file paths, model details or library internals.
        logger.exception("Classification error")
        raise HTTPException(status_code=500, detail="Classification failed") from e


if __name__ == "__main__":
    # Script entry point: print a usage banner, then serve the app with uvicorn.
    import uvicorn

    # Bind address; defaults preserve the original 0.0.0.0:8000.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", "8000"))

    print(f"""
    ========================================
    OLIFANT TiMBL API Server
    ========================================

    Starting server on {host}:{port}

    Set these environment variables:
      TIMBL_API_TOKEN     - Secret token for authentication
      TIMBL_MODEL_PATH    - Path to .ibase file
      TIMBL_ARGS          - TiMBL arguments (default: -a1 +D)
      RATE_LIMIT_REQUESTS - Max /classify requests per window (default: 100)
      HOST / PORT         - Bind address (default: 0.0.0.0 / 8000)

    For HTTPS, run behind nginx/Caddy or use:
      uvicorn timbl_api_server:app --host 0.0.0.0 --port 8000 --ssl-keyfile key.pem --ssl-certfile cert.pem

    ========================================
    """)

    uvicorn.run(app, host=host, port=port)