"""FastAPI service exposing the CBT binary text classifier."""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import logging
from pathlib import Path
import sys
import os
from huggingface_hub import snapshot_download
# Add parent directory to path for imports
sys.path.append(str(Path(__file__).parent))
from binary_classifier import CBTBinaryClassifier
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create FastAPI app
# FastAPI application instance; title/description/version appear in the
# auto-generated OpenAPI docs (/docs).
app = FastAPI(
    title="CBT Binary Classifier API",
    description="API for detecting CBT-triggering conversations",
    version="1.0.0"
)
# Request/Response models
class TextRequest(BaseModel):
    """Request body for single-text classification (POST /classify)."""

    text: str = Field(..., description="Text to classify")
    # Probability cutoff forwarded to the classifier; exact comparison
    # semantics live in CBTBinaryClassifier.predict — presumably a value
    # in [0, 1], TODO confirm against the classifier implementation.
    threshold: float = Field(0.7, description="Confidence threshold for CBT trigger detection")
class BatchTextRequest(BaseModel):
    """Request body for multi-text classification (POST /classify/batch)."""

    texts: List[str] = Field(..., description="List of texts to classify")
    # Single threshold applied to every text in the batch.
    threshold: float = Field(0.7, description="Confidence threshold for CBT trigger detection")
class PredictionResponse(BaseModel):
    """Classification verdict for one input text."""

    is_cbt_trigger: bool  # verdict as reported by the classifier
    confidence: float     # classifier confidence score
    threshold: float      # threshold that was applied for this prediction
    # Echo of the input; the endpoints truncate it to 100 chars + "...".
    text: Optional[str] = None
class BatchPredictionResponse(BaseModel):
    """Ordered list of per-text verdicts for a batch request."""

    predictions: List[PredictionResponse]
# Initialize classifier
classifier = None
@app.on_event("startup")
async def startup_event():
    """Load the classification model when the server starts.

    Uses a local model directory when USE_LOCAL_MODEL=true and the
    directory exists; otherwise downloads a snapshot of the repo named
    by HF_MODEL_ID from the Hugging Face Hub.

    Raises:
        Exception: re-raised after logging when the model cannot be
            loaded, so a broken model aborts server startup instead of
            leaving every request answering 503.
    """
    global classifier
    try:
        classifier = CBTBinaryClassifier()

        # Hub repo to pull from when not using a local copy.
        hf_model_id = os.getenv("HF_MODEL_ID", "SaitejaJate/Binary_classifier")
        local_model_path = Path(__file__).parent / "cbt_classifier"

        # Opt-in switch for a pre-downloaded local model.
        use_local = os.getenv("USE_LOCAL_MODEL", "false").lower() == "true"

        if use_local and local_model_path.exists():
            classifier.load_model(str(local_model_path))
            logger.info("Model loaded successfully from local path: %s", local_model_path)
        else:
            logger.info("Downloading model from Hugging Face Hub: %s", hf_model_id)
            cache_dir = Path(__file__).parent / "model_cache"
            # snapshot_download returns the directory the files landed in.
            model_path = snapshot_download(
                repo_id=hf_model_id,
                cache_dir=str(cache_dir),
                local_dir=str(cache_dir / "downloaded_model"),
            )
            classifier.load_model(model_path)
            # Plain string here: the original used an f-string with no
            # placeholders, which is a lint error (F541-style smell).
            logger.info("Model loaded successfully from Hugging Face Hub")
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        raise
@app.get("/")
async def root():
    """Health check endpoint: reports service status and model readiness."""
    model_ready = classifier is not None
    payload = {
        "status": "active",
        "service": "CBT Binary Classifier API",
        "model_loaded": model_ready,
    }
    return payload
@app.post("/classify", response_model=PredictionResponse)
async def classify_text(request: TextRequest):
    """Classify a single text.

    Returns:
        PredictionResponse with the classifier verdict and a preview of
        the input text truncated to 100 characters.

    Raises:
        HTTPException: 503 when the model is not loaded yet, 500 on any
            classification failure.
    """
    # BUG FIX: the readiness check used to sit inside the try block, so
    # the 503 HTTPException was caught by `except Exception` below and
    # re-raised as a 500. Check before entering the try.
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        result = classifier.predict(request.text, request.threshold)
        # Echo back at most 100 characters of the input.
        preview = request.text[:100] + "..." if len(request.text) > 100 else request.text
        return PredictionResponse(
            is_cbt_trigger=result['is_cbt_trigger'],
            confidence=result['confidence'],
            threshold=result['threshold'],
            text=preview,
        )
    except Exception as e:
        logger.error("Classification error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/classify/batch", response_model=BatchPredictionResponse)
async def classify_batch(request: BatchTextRequest):
    """Classify multiple texts in one call.

    Returns:
        BatchPredictionResponse whose predictions are in the same order
        as the input texts; each echoes a 100-character preview.

    Raises:
        HTTPException: 503 when the model is not loaded yet, 500 on any
            classification failure.
    """
    # BUG FIX: the readiness check used to sit inside the try block, so
    # the 503 HTTPException was caught by `except Exception` below and
    # re-raised as a 500. Check before entering the try.
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        results = classifier.batch_predict(request.texts, request.threshold)
        predictions = []
        # zip pairs each input with its result — clearer than enumerate
        # plus repeated request.texts[i] indexing.
        for text, result in zip(request.texts, results):
            preview = text[:100] + "..." if len(text) > 100 else text
            predictions.append(PredictionResponse(
                is_cbt_trigger=result['is_cbt_trigger'],
                confidence=result['confidence'],
                threshold=result['threshold'],
                text=preview,
            ))
        return BatchPredictionResponse(predictions=predictions)
    except Exception as e:
        logger.error("Batch classification error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/model/info")
async def model_info():
    """Report the name, expected path, and status of the loaded model."""
    # Guard clause: no model means nothing to report.
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    default_path = Path(__file__).parent / "cbt_classifier"
    return {
        "model_name": classifier.model_name,
        "model_path": str(default_path),
        "status": "loaded",
    }
if __name__ == "__main__":
    # Local import keeps uvicorn optional when the app is mounted by an
    # external server instead of run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8001)