File size: 5,199 Bytes
88b8fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
import logging
from pathlib import Path
import sys
import os
from huggingface_hub import snapshot_download

# Make sibling modules (e.g. binary_classifier) importable when this file is
# run directly rather than installed as a package.
sys.path.append(str(Path(__file__).parent))

from binary_classifier import CBTBinaryClassifier

# Configure root logging once; module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the FastAPI application; title/description/version surface in the
# auto-generated OpenAPI docs (/docs, /redoc).
app = FastAPI(
    title="CBT Binary Classifier API",
    description="API for detecting CBT-triggering conversations",
    version="1.0.0"
)

# Request/Response models
class TextRequest(BaseModel):
    """Request body for single-text classification (POST /classify)."""
    text: str = Field(..., description="Text to classify")
    threshold: float = Field(0.7, description="Confidence threshold for CBT trigger detection")

class BatchTextRequest(BaseModel):
    """Request body for multi-text classification (POST /classify/batch)."""
    texts: List[str] = Field(..., description="List of texts to classify")
    threshold: float = Field(0.7, description="Confidence threshold for CBT trigger detection")

class PredictionResponse(BaseModel):
    """Single prediction result returned by the classify endpoints."""
    is_cbt_trigger: bool
    confidence: float
    threshold: float
    # Preview of the classified input; endpoints truncate it to 100 chars + "...".
    text: Optional[str] = None

class BatchPredictionResponse(BaseModel):
    """Response body for POST /classify/batch: one prediction per input text."""
    predictions: List[PredictionResponse]

# Module-level classifier instance; None until startup_event() loads the model.
# Endpoints treat None as "model not loaded" and answer 503.
classifier = None

@app.on_event("startup")
async def startup_event():
    """Load the CBT classifier once at application startup.

    Uses the local model directory when USE_LOCAL_MODEL=true and the
    directory exists; otherwise downloads the model snapshot from the
    Hugging Face Hub repo named by HF_MODEL_ID. Any load failure is
    logged with its traceback and re-raised so the app fails fast
    instead of serving without a model.
    """
    global classifier
    try:
        classifier = CBTBinaryClassifier()

        # Model source configuration (env-overridable).
        hf_model_id = os.getenv("HF_MODEL_ID", "SaitejaJate/Binary_classifier")
        local_model_path = Path(__file__).parent / "cbt_classifier"
        use_local = os.getenv("USE_LOCAL_MODEL", "false").lower() == "true"

        if use_local and local_model_path.exists():
            # Use local model
            classifier.load_model(str(local_model_path))
            logger.info("Model loaded successfully from local path: %s", local_model_path)
        else:
            # Download from Hugging Face Hub
            logger.info("Downloading model from Hugging Face Hub: %s", hf_model_id)
            cache_dir = Path(__file__).parent / "model_cache"

            # NOTE(review): recent huggingface_hub versions ignore cache_dir when
            # local_dir is given — confirm both are intentionally passed here.
            model_path = snapshot_download(
                repo_id=hf_model_id,
                cache_dir=str(cache_dir),
                local_dir=str(cache_dir / "downloaded_model")
            )

            classifier.load_model(model_path)
            logger.info("Model loaded successfully from Hugging Face Hub")

    except Exception:
        # logger.exception records the full traceback, which logger.error(f"...")
        # discarded; the re-raise aborts startup.
        logger.exception("Failed to load model")
        raise

@app.get("/")
async def root():
    """Health check: report service identity and whether the model is loaded."""
    model_ready = classifier is not None
    return {
        "status": "active",
        "service": "CBT Binary Classifier API",
        "model_loaded": model_ready,
    }

@app.post("/classify", response_model=PredictionResponse)
async def classify_text(request: TextRequest):
    """Classify a single text as CBT-triggering or not.

    Returns the prediction together with a preview of the input
    (truncated to 100 characters). Responds 503 when the model has not
    been loaded yet and 500 on unexpected classification failures.
    """
    # Readiness check lives OUTSIDE the try block: previously the
    # HTTPException(503) raised inside it was caught by `except Exception`
    # below and re-raised as a 500, hiding the intended status code.
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = classifier.predict(request.text, request.threshold)

        # Truncate long inputs so the echoed text stays small.
        text_preview = request.text[:100] + "..." if len(request.text) > 100 else request.text
        return PredictionResponse(
            is_cbt_trigger=result['is_cbt_trigger'],
            confidence=result['confidence'],
            threshold=result['threshold'],
            text=text_preview
        )
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Classification error")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/classify/batch", response_model=BatchPredictionResponse)
async def classify_batch(request: BatchTextRequest):
    """Classify multiple texts in one call.

    Returns one PredictionResponse per input text, each with a preview
    of its text (truncated to 100 characters). Responds 503 when the
    model has not been loaded yet and 500 on unexpected failures.
    """
    # Readiness check lives OUTSIDE the try block: previously the
    # HTTPException(503) raised inside it was caught by `except Exception`
    # below and re-raised as a 500, hiding the intended status code.
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        results = classifier.batch_predict(request.texts, request.threshold)

        predictions = []
        # zip pairs each input text with its prediction (replaces index loop).
        for text, result in zip(request.texts, results):
            text_preview = text[:100] + "..." if len(text) > 100 else text
            predictions.append(PredictionResponse(
                is_cbt_trigger=result['is_cbt_trigger'],
                confidence=result['confidence'],
                threshold=result['threshold'],
                text=text_preview
            ))

        return BatchPredictionResponse(predictions=predictions)
    except Exception as e:
        # logger.exception keeps the traceback for debugging.
        logger.exception("Batch classification error")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/model/info")
async def model_info():
    """Return metadata about the currently loaded model, or 503 if none."""
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # NOTE(review): this always reports the default local directory, even when
    # the model was downloaded from the Hub — confirm that is intended.
    default_model_dir = Path(__file__).parent / "cbt_classifier"
    return {
        "model_name": classifier.model_name,
        "model_path": str(default_model_dir),
        "status": "loaded",
    }

# Development entry point: run a uvicorn server directly (uvicorn imported
# lazily so merely importing this module does not require it).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)