############################ HuggingFace + GitHub ###############################
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import numpy as np
import requests
import tempfile
import os
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="ML Model API", version="1.0.0")
# Add CORS middleware - THIS IS CRITICAL FOR HUGGING FACE
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],      # Allows all methods
    allow_headers=["*"],      # Allows all headers
)
# GitHub raw content URLs - using YOUR actual GitHub repository
GITHUB_MODEL_URL = "https://raw.githubusercontent.com/SagarChhabriya/ml-model-api/main/model/model.joblib"
GITHUB_FEATURES_URL = "https://raw.githubusercontent.com/SagarChhabriya/ml-model-api/main/model/feature_names.joblib"
# Load model from GitHub
def load_model_from_github():
    try:
        print("📥 Downloading model from GitHub...")

        # Download model file
        model_response = requests.get(GITHUB_MODEL_URL)
        model_response.raise_for_status()  # Raise error if download fails

        # Download feature names file
        features_response = requests.get(GITHUB_FEATURES_URL)
        features_response.raise_for_status()

        # Save to temporary files
        with tempfile.NamedTemporaryFile(delete=False, suffix='.joblib') as model_tmp:
            model_tmp.write(model_response.content)
            model_path = model_tmp.name

        with tempfile.NamedTemporaryFile(delete=False, suffix='.joblib') as features_tmp:
            features_tmp.write(features_response.content)
            features_path = features_tmp.name

        # Load the files
        model = joblib.load(model_path)
        feature_names = joblib.load(features_path)

        # Clean up temporary files
        os.unlink(model_path)
        os.unlink(features_path)

        print("✅ Model loaded successfully from GitHub!")
        print(f"📊 Model type: {type(model).__name__}")
        print(f"📈 Features: {feature_names}")

        return model, feature_names
    except Exception as e:
        print(f"❌ Error loading from GitHub: {e}")
        return None, []
# Load model on startup
model, feature_names = load_model_from_github()
# Define input schema
class PredictionInput(BaseModel):
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

class BatchPredictionInput(BaseModel):
    data: list[list[float]]
@app.get("/")
async def root():
return {
"message": "ML Model API deployed on Hugging Face Spaces! πŸš€",
"endpoints": {
"health": "/health",
"single_prediction": "/predict",
"batch_prediction": "/predict-batch",
"model_info": "/model-info",
"docs": "/docs"
},
"model_loaded": model is not None,
"model_source": "GitHub"
}
@app.get("/health")
async def health_check():
return {
"status": "healthy" if model else "unhealthy",
"model_loaded": model is not None,
"model_type": "RandomForestClassifier" if model else "None"
}
@app.get("/model-info")
async def model_info():
if not model:
raise HTTPException(status_code=500, detail="Model not loaded")
return {
"model_type": str(type(model).__name__),
"feature_names": feature_names,
"n_features": len(feature_names),
"n_classes": getattr(model, 'n_classes_', 'Unknown')
}
@app.post("/predict")
async def predict_single(input_data: PredictionInput):
if not model:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
# Convert input to array
features = np.array([
input_data.sepal_length,
input_data.sepal_width,
input_data.petal_length,
input_data.petal_width
]).reshape(1, -1)
# Make prediction
prediction = model.predict(features)
probabilities = model.predict_proba(features)
return {
"prediction": int(prediction[0]),
"probabilities": probabilities[0].tolist(),
"class_names": ["setosa", "versicolor", "virginica"],
"input_features": input_data.dict()
}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}")
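
# Example call for the /predict endpoint (a usage sketch, not part of the app itself;
# the Space URL below is a placeholder — substitute your own deployment URL):
#
#   curl -X POST "https://<your-space>.hf.space/predict" \
#        -H "Content-Type: application/json" \
#        -d '{"sepal_length": 5.1, "sepal_width": 3.5, "petal_length": 1.4, "petal_width": 0.2}'
#
# The response contains "prediction", "probabilities", "class_names", and "input_features".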
@app.post("/predict-batch")
async def predict_batch(input_data: BatchPredictionInput):
if not model:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
# Convert to numpy array
features = np.array(input_data.data)
# Validate input shape
if features.shape[1] != len(feature_names):
raise HTTPException(
status_code=400,
detail=f"Expected {len(feature_names)} features, got {features.shape[1]}"
)
# Make predictions
predictions = model.predict(features)
probabilities = model.predict_proba(features)
return {
"predictions": predictions.tolist(),
"probabilities": probabilities.tolist(),
"batch_size": len(predictions)
}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}")
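
# Example /predict-batch payload (illustrative values only; feature order follows
# sepal_length, sepal_width, petal_length, petal_width as in PredictionInput):
#
#   {"data": [[5.1, 3.5, 1.4, 0.2], [6.2, 3.4, 5.4, 2.3]]}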
@app.get("/debug")
async def debug():
"""Debug endpoint to check model loading status"""
return {
"model_loaded": model is not None,
"features_loaded": len(feature_names) > 0 if feature_names else False,
"feature_names": feature_names if feature_names else "Not loaded",
"model_type": str(type(model).__name__) if model else "Not loaded",
"github_model_url": GITHUB_MODEL_URL,
"github_features_url": GITHUB_FEATURES_URL
}
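
# A minimal local entry point — a sketch that assumes the Space runs the app with
# uvicorn; port 7860 is the Hugging Face Spaces default, adjust if your setup differs.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)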