from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import os
from pathlib import Path
from huggingface_hub import hf_hub_download

# Initialize FastAPI app
app = FastAPI(
    title="Sentiment Analysis API",
    description="API for sentiment analysis using ensemble ML models",
    version="1.0.0"
)

# Configure CORS - allow all origins for Vercel frontend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your Vercel domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
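
# Example (placeholder domain, not from the original): to restrict CORS to a
# single Vercel frontend in production, swap the wildcard above for an explicit
# origin list, e.g.:
#     allow_origins=["https://your-app.vercel.app"]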

# Define request/response models
class TextInput(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    predicted_sentiment: str
    input_text: str


class StatusResponse(BaseModel):
    status: str
    model_name: str
    message: str
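
# Illustrative /predict request/response shapes (example values only; the actual
# sentiment labels depend on the fitted label encoder):
#   request:  {"text": "I really enjoyed this product"}
#   response: {"predicted_sentiment": "positive", "input_text": "I really enjoyed this product"}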

# Global variables for models
label_encoder = None
tfidf_vectorizer = None
voting_classifier = None
MODEL_NAME = "Voting Classifier (KNN + RF + ET)"

# HuggingFace Model Hub configuration
REPO_ID = "anis80/anisproject"  # Your HuggingFace model repository
MODEL_FILES = {
    "label_encoder": "label_encoder.joblib",
    "tfidf_vectorizer": "tfidf_vectorizer.joblib",
    "voting_classifier": "voting_classifier_knn_rf_et-001.joblib"
}

def download_model_from_hub(filename: str) -> str:
    """Download a model file from HuggingFace Model Hub."""
    try:
        print(f"Downloading {filename} from HuggingFace Model Hub...")
        file_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            cache_dir="./model_cache"
        )
        print(f"Downloaded {filename}")
        return file_path
    except Exception as e:
        print(f"Error downloading {filename}: {str(e)}")
        raise e
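
# Note (added comment, not in the original): hf_hub_download() caches files under
# cache_dir, so repeated cold starts reuse the local copy when the file on the
# Hub has not changed, instead of downloading it again.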

# Load models on startup
@app.on_event("startup")
async def load_models():
    global label_encoder, tfidf_vectorizer, voting_classifier
    try:
        print(f"Starting model loading from HuggingFace: {REPO_ID}")

        # Download and load each model
        label_encoder_path = download_model_from_hub(MODEL_FILES["label_encoder"])
        label_encoder = joblib.load(label_encoder_path)

        tfidf_path = download_model_from_hub(MODEL_FILES["tfidf_vectorizer"])
        tfidf_vectorizer = joblib.load(tfidf_path)

        classifier_path = download_model_from_hub(MODEL_FILES["voting_classifier"])
        voting_classifier = joblib.load(classifier_path)

        print("All models loaded successfully from HuggingFace Model Hub!")
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        print(f"Make sure models are uploaded to: https://huggingface.co/{REPO_ID}")
        raise e
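
# Note on the decorator above (assumption): the original listing only said
# "Load models on startup", so @app.on_event("startup") is the wiring assumed
# here; newer FastAPI releases express the same hook via a lifespan context
# manager.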

# Health check endpoint
@app.get("/")
async def root():
    return {
        "message": "Sentiment Analysis API is running",
        "model_source": f"HuggingFace: {REPO_ID}",
        "endpoints": {
            "predict": "/predict",
            "status": "/status",
            "docs": "/docs"
        }
    }

# Status endpoint
@app.get("/status", response_model=StatusResponse)
async def get_status():
    if voting_classifier is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    return StatusResponse(
        status="ready",
        model_name=MODEL_NAME,
        message=f"All models loaded from {REPO_ID}"
    )
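
# Illustrative /status response once models are loaded (values taken from the
# constants defined above):
#   {"status": "ready", "model_name": "Voting Classifier (KNN + RF + ET)",
#    "message": "All models loaded from anis80/anisproject"}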

# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_sentiment(input_data: TextInput):
    try:
        # Validate models are loaded
        if any(m is None for m in (label_encoder, tfidf_vectorizer, voting_classifier)):
            raise HTTPException(
                status_code=503,
                detail="Models not loaded. Please try again later."
            )

        # Validate input
        if not input_data.text or not input_data.text.strip():
            raise HTTPException(
                status_code=400,
                detail="Text input cannot be empty"
            )

        # Preprocess and transform the text
        try:
            print(f"Transforming text: '{input_data.text}'")
            text_tfidf = tfidf_vectorizer.transform([input_data.text])
            print(f"Shape: {text_tfidf.shape}")

            # Make prediction
            print("Predicting...")
            prediction = voting_classifier.predict(text_tfidf)
            print(f"Raw prediction: {prediction}")

            # Decode the prediction
            sentiment = label_encoder.inverse_transform(prediction)[0]
            print(f"Sentiment: {sentiment}")

            return PredictionResponse(
                predicted_sentiment=sentiment,
                input_text=input_data.text
            )
        except Exception as e:
            import traceback
            error_trace = traceback.format_exc()
            print(f"Prediction logic error: {str(e)}\n{error_trace}")
            raise HTTPException(
                status_code=500,
                detail=f"Model error: {str(e)}"
            )
    except HTTPException:
        raise
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Prediction failed: {str(e)}"
        )

# For local testing
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
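
# Example local test (illustrative commands, assuming the server is running on
# port 7860 as configured above):
#   curl http://localhost:7860/status
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "I really enjoyed this product"}'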