# Page residue kept for provenance (not part of the program):
# nomandiu9's picture
# Commit: Fix scikit-learn version mismatch (upgrade to 1.6.1)
# Commit hash: 9499f51
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import os
from pathlib import Path
from huggingface_hub import hf_hub_download
# Initialize the FastAPI application with its title, description and version.
app = FastAPI(
    version="1.0.0",
    title="Sentiment Analysis API",
    description="API for sentiment analysis using ensemble ML models",
)
# Configure CORS - allow all origins for Vercel frontend.
#
# A wildcard origin must not be combined with allow_credentials=True: the
# FastAPI CORS documentation requires explicit origins (and methods/headers)
# whenever credentials are allowed. None of the endpoints visible in this
# file read cookies or authorization headers, so credentials are disabled.
# To turn them back on, list the exact frontend origin(s) in allow_origins
# instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your Vercel domain
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define request/response models
class TextInput(BaseModel):
    """Request body accepted by the prediction endpoint."""

    # Raw text whose sentiment should be predicted.
    text: str
class PredictionResponse(BaseModel):
    """Response body returned by the prediction endpoint."""

    # Label decoded from the classifier output.
    predicted_sentiment: str
    # The submitted text, echoed back to the caller.
    input_text: str
class StatusResponse(BaseModel):
    """Response body returned by the status endpoint."""

    # Readiness marker (the status endpoint sends "ready").
    status: str
    # Display name of the model serving predictions.
    model_name: str
    # Free-form detail about where the models were loaded from.
    message: str
# Module-level holders for the model objects. They start out as None and are
# expected to be assigned by the startup hook before requests are served.
label_encoder = tfidf_vectorizer = voting_classifier = None

# Display name of the ensemble, reported by the status endpoint.
MODEL_NAME = "Voting Classifier (KNN + RF + ET)"
# HuggingFace Model Hub configuration
REPO_ID = "anis80/anisproject"  # Your HuggingFace model repository

# Maps each logical artifact name to its filename inside the repository.
MODEL_FILES = dict(
    label_encoder="label_encoder.joblib",
    tfidf_vectorizer="tfidf_vectorizer.joblib",
    voting_classifier="voting_classifier_knn_rf_et-001.joblib",
)
def download_model_from_hub(filename: str) -> str:
    """Download a model file from HuggingFace Model Hub.

    Args:
        filename: Name of the file inside the ``REPO_ID`` repository.

    Returns:
        The local file path returned by ``hf_hub_download``.

    Raises:
        Exception: Whatever ``hf_hub_download`` raised; the failure is
            printed and then re-raised unchanged.
    """
    print(f"πŸ“₯ Downloading {filename} from HuggingFace Model Hub...")
    try:
        # Only the network call sits inside the try so that unrelated errors
        # are not reported as download failures.
        file_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=filename,
            cache_dir="./model_cache",
        )
    except Exception as e:
        print(f"❌ Error downloading {filename}: {e}")
        # Bare raise re-raises the active exception with its traceback intact.
        raise
    print(f"βœ… Downloaded {filename}")
    return file_path
# Load models on startup
# NOTE(review): FastAPI documents `on_event` as deprecated in favor of
# lifespan handlers; migrating would also touch the `FastAPI(...)` call.
@app.on_event("startup")
async def load_models():
    """Download and deserialize every model artifact, then publish the globals.

    Each entry of ``MODEL_FILES`` is fetched with ``download_model_from_hub``
    and loaded with ``joblib.load``. The module globals ``label_encoder``,
    ``tfidf_vectorizer`` and ``voting_classifier`` are assigned only after
    all three artifacts loaded, so a failure part-way through never leaves a
    partially populated state behind.

    Raises:
        Exception: Any download or deserialization error; it is printed
            together with a hint about the repository and re-raised.
    """
    global label_encoder, tfidf_vectorizer, voting_classifier
    try:
        print(f"πŸš€ Starting model loading from HuggingFace: {REPO_ID}")
        # SECURITY: joblib.load unpickles the file, and unpickling can run
        # arbitrary code. Only point REPO_ID at a repository you trust.
        loaded = {}
        # Same order as before: label encoder, vectorizer, classifier.
        for name in ("label_encoder", "tfidf_vectorizer", "voting_classifier"):
            artifact_path = download_model_from_hub(MODEL_FILES[name])
            loaded[name] = joblib.load(artifact_path)
        # Publish everything together once all artifacts are available.
        label_encoder = loaded["label_encoder"]
        tfidf_vectorizer = loaded["tfidf_vectorizer"]
        voting_classifier = loaded["voting_classifier"]
        print("βœ… All models loaded successfully from HuggingFace Model Hub!")
    except Exception as e:
        print(f"❌ Error loading models: {e}")
        print(f"⚠️ Make sure models are uploaded to: https://huggingface.co/{REPO_ID}")
        # Bare raise keeps the original traceback.
        raise
# Health check endpoint
@app.get("/")
async def root():
    """Return a running message, the model source and the endpoint paths."""
    endpoint_paths = {
        "predict": "/predict",
        "status": "/status",
        "docs": "/docs",
    }
    return {
        "message": "Sentiment Analysis API is running",
        "model_source": f"HuggingFace: {REPO_ID}",
        "endpoints": endpoint_paths,
    }
# Status endpoint
@app.get("/status", response_model=StatusResponse)
async def get_status():
    """Report readiness of the service.

    Returns a ``StatusResponse`` with status "ready" once the voting
    classifier global has been set; otherwise raises ``HTTPException`` 503.
    """
    if voting_classifier is not None:
        return StatusResponse(
            status="ready",
            model_name=MODEL_NAME,
            message=f"All models loaded from {REPO_ID}",
        )
    raise HTTPException(status_code=503, detail="Models not loaded")
# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_sentiment(input_data: TextInput):
    """Predict the sentiment of ``input_data.text``.

    The text is vectorized with ``tfidf_vectorizer``, classified with
    ``voting_classifier`` and the result is decoded with ``label_encoder``.

    Args:
        input_data: Request body carrying the text to classify.

    Returns:
        A ``PredictionResponse`` with the decoded label and the input text.

    Raises:
        HTTPException: 503 if any model global is still ``None``; 400 if the
            text is empty or whitespace-only; 500 if vectorizing, predicting
            or decoding raises.
    """
    # NOTE(review): the transform/predict calls below run synchronously inside
    # an async handler; consider offloading them if latency under load matters.

    # Validate models are loaded. An identity test is used so that no `==`
    # comparison is ever invoked on the model objects themselves.
    models = (label_encoder, tfidf_vectorizer, voting_classifier)
    if any(model is None for model in models):
        raise HTTPException(
            status_code=503,
            detail="Models not loaded. Please try again later.",
        )

    # Validate input: "" and whitespace-only text both strip to "".
    if not input_data.text or not input_data.text.strip():
        raise HTTPException(
            status_code=400,
            detail="Text input cannot be empty",
        )

    try:
        # Preprocess and transform the text
        print(f"Transforming text: '{input_data.text}'")
        text_tfidf = tfidf_vectorizer.transform([input_data.text])
        print(f"Shape: {text_tfidf.shape}")

        # Make prediction
        print("Predicting...")
        prediction = voting_classifier.predict(text_tfidf)
        print(f"Raw prediction: {prediction}")

        # Decode the prediction. str() guarantees the response field gets a
        # plain string even if the encoder's classes are not str objects.
        sentiment = str(label_encoder.inverse_transform(prediction)[0])
        print(f"Sentiment: {sentiment}")

        return PredictionResponse(
            predicted_sentiment=sentiment,
            input_text=input_data.text,
        )
    except Exception as e:
        import traceback

        error_trace = traceback.format_exc()
        print(f"Prediction logic error: {e}\n{error_trace}")
        # Chain the cause so the original error stays attached to the 500.
        raise HTTPException(
            status_code=500,
            detail=f"Model error: {e}",
        ) from e
# For local testing
if __name__ == "__main__":
    import uvicorn

    # Serve the app on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")