Spaces:
Build error
Build error
| # app.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import logging | |
| import numpy as np | |
| from typing import List, Optional, Dict, Any, Union | |
| import sys | |
| import os | |
| # Import the InterestClassifier from your model file | |
| # Make sure this file is in the same directory as app.py | |
| from hybrid_interest_classifier import InterestClassifier, INTEREST_CATEGORIES | |
# Configure logging: INFO and above, each record stamped with time, logger
# name and severity.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
# Initialize FastAPI app
app = FastAPI()

# Allow CORS: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm whether credentialed
# cross-origin requests are really needed and, if so, list explicit origins.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Keyword-based interest detection used as a fallback.
# Each entry pairs an interest label with the word stems that trigger it; the
# order here is the order in which labels are reported.
_INTEREST_KEYWORDS = (
    ("Music", ("music", "band", "concert", "sing", "guitar", "song")),
    ("Food", ("food", "cook", "recipe", "restaurant", "eat", "cuisine")),
    ("Sports", ("sport", "gym", "fitness", "exercise", "workout", "run")),
    ("Arts", ("art", "paint", "draw", "gallery", "museum", "exhibition")),
    ("Technology", ("tech", "code", "software", "computer", "programming")),
    ("Education", ("learn", "study", "course", "book", "read", "class")),
    ("Travel", ("travel", "trip", "journey", "explore", "hike", "tourism")),
)


def keyword_interests(text):
    """
    Determine interests using keyword matching as a fallback.

    A label is reported when any of its stems appears at the START of a word
    in ``text`` (case-insensitive), so "cooking" still triggers Food via
    "cook" while "great" no longer triggers Food via an embedded "eat" and
    "party" no longer triggers Arts via an embedded "art". The trade-off is
    that a stem buried inside a compound (e.g. "seafood") is not matched.

    Args:
        text: Free-form input text. ``None`` or an empty string is treated as
            containing no keywords.

    Returns:
        A list of matched interest labels in table order, or
        ``['No specific interests detected']`` when nothing matches.
    """
    import re  # local import so this block does not depend on file-level imports

    lowered = (text or "").lower()
    interests = [
        label
        for label, stems in _INTEREST_KEYWORDS
        # \b anchors the alternation to a word start; stems are escaped so the
        # table can safely hold characters that are special in a regex.
        if re.search(r"\b(?:" + "|".join(map(re.escape, stems)) + ")", lowered)
    ]
    return interests or ["No specific interests detected"]
# Load the hybrid classifier
MODEL_PATH = "hybrid_interest_classifier.pkl"

# Remains None when loading fails, so the rest of the module can test for it.
hybrid_classifier = None

try:
    logger.info(f"Loading hybrid model from {MODEL_PATH}")
    # Constructing the classifier with model_path loads the saved model
    hybrid_classifier = InterestClassifier(model_path=MODEL_PATH)
    logger.info("Hybrid model loaded successfully")

    # Report which scoring components are usable
    if hybrid_classifier.bert_classifier is None:
        logger.warning("BERT zero-shot classifier is not available, will use TF-IDF only")
    else:
        logger.info("BERT zero-shot classifier initialized and ready")
except Exception as e:
    # Deliberately best-effort: the app still starts without the model.
    logger.error(f"Failed to load hybrid model: {e}")
| # Pydantic models | |
# Request body for a prediction call. Only `text` is required; the other
# fields are optional and, when left unset, are not forwarded to the model.
class PredictionRequest(BaseModel):
    # Input text to classify
    text: str

    # Optional per-request overrides passed through to the classifier
    alpha: Optional[float] = None
    threshold: Optional[float] = None

    # When true, the response carries score details in addition to labels
    return_scores: Optional[bool] = False
# Request body for a configuration update. A field left as None leaves the
# corresponding classifier setting untouched.
class ModelConfigRequest(BaseModel):
    # New value for the classifier's `alpha` attribute
    alpha: Optional[float] = None

    # New value for the classifier's `threshold` attribute
    threshold: Optional[float] = None
# NOTE(review): no route decorator is visible above this handler in this view —
# confirm that it is registered on `app`.
async def root():
    """Root endpoint to check if API is running"""
    # BERT can only be reported as available when a classifier object exists
    bert_available = False
    if hybrid_classifier:
        bert_available = hybrid_classifier.bert_classifier is not None

    return {
        "status": "online",
        "message": "Hybrid Interest Classifier API is running",
        "model_loaded": hybrid_classifier is not None,
        "bert_available": bert_available,
    }
# NOTE(review): no route decorator is visible above this handler in this view —
# confirm that it is registered on `app`.
async def health():
    """Health check endpoint"""
    # BERT can only be reported as available when a classifier object exists
    bert_available = False
    if hybrid_classifier:
        bert_available = hybrid_classifier.bert_classifier is not None

    return {
        "status": "healthy",
        "model_loaded": hybrid_classifier is not None,
        "bert_available": bert_available,
    }
# NOTE(review): no route decorator is visible above this handler in this view —
# confirm that it is registered on `app`.
async def update_config(config: ModelConfigRequest):
    """Update model configuration"""
    # Nothing to configure until the model has been loaded
    if hybrid_classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    changes = {}
    for setting in ("alpha", "threshold"):
        requested = getattr(config, setting)
        if requested is None:
            continue  # field omitted: keep the current value
        setattr(hybrid_classifier, setting, float(requested))
        # Echo back what the classifier now holds, not the raw request value
        changes[setting] = getattr(hybrid_classifier, setting)

    current_config = {
        "alpha": hybrid_classifier.alpha,
        "threshold": hybrid_classifier.threshold,
        "bert_available": hybrid_classifier.bert_classifier is not None,
    }
    return {
        "message": "Configuration updated successfully",
        "changes": changes,
        "current_config": current_config,
    }
def _top_five(score_map):
    """Return the five highest-valued entries of score_map as a dict, best first."""
    ranked = sorted(score_map.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:5])


# NOTE(review): no route decorator is visible above this handler in this view —
# confirm that it is registered on `app`.
async def predict(request: PredictionRequest):
    """
    Predict interests based on text input
    """
    text = request.text
    alpha = request.alpha
    threshold = request.threshold
    return_scores = request.return_scores
    logger.info(f"Prediction request: text='{text[:50]}...', alpha={alpha}, threshold={threshold}, return_scores={return_scores}")

    # Blank input never reaches the model
    if not text or text.strip() == "":
        return {"labels": ["No text provided"], "text": text}

    # Without a loaded model, keyword matching is the only option
    if hybrid_classifier is None:
        logger.warning("Using fallback keyword matching (model not loaded)")
        return {"labels": keyword_interests(text), "text": text}

    try:
        # Forward only the options the caller actually supplied
        kwargs = {
            name: value
            for name, value in (("alpha", alpha), ("threshold", threshold))
            if value is not None
        }
        if return_scores:
            kwargs['return_scores'] = True

        logger.info(f"Calling hybrid_classifier.predict([{text[:20]}...], {kwargs})")
        prediction = hybrid_classifier.predict(text, **kwargs)
        logger.info(f"Raw prediction type: {type(prediction)}")

        if isinstance(prediction, dict):
            labels = prediction.get('labels', [])

            if return_scores:
                # Detailed response: labels plus whatever score data came back
                response = {
                    "labels": labels,
                    "text": text,
                    "scores": dict(prediction.get('sorted_scores', [])),
                    "model_info": {
                        "alpha": prediction.get('alpha', hybrid_classifier.alpha),
                        "threshold": prediction.get('threshold', hybrid_classifier.threshold),
                        "using_bert": prediction.get('using_bert', False),
                    },
                }

                if 'timing' in prediction:
                    response["timing"] = prediction['timing']

                # Per-model scores are trimmed to their five best entries
                for score_key in ("tfidf_scores", "bert_scores"):
                    if score_key in prediction:
                        response[score_key] = _top_five(prediction[score_key])

                return response
        elif isinstance(prediction, list):
            labels = prediction
        else:
            # Any other return type is treated as "no labels"
            labels = []

        # An empty model result falls back to keyword matching
        if not labels:
            logger.warning("No labels detected, using fallback")
            labels = keyword_interests(text)

        return {"labels": labels, "text": text}
    except Exception as e:
        # Best-effort endpoint: log the failure and still answer with keyword labels
        logger.error(f"Error during prediction: {e}", exc_info=True)
        return {"labels": keyword_interests(text), "text": text, "error": str(e)}
if __name__ == "__main__":
    import uvicorn

    # Direct-run entry point: serve `app` from this module on all interfaces,
    # port 8000, with reload enabled.
    uvicorn.run("app:app", reload=True, port=8000, host="0.0.0.0")