"""Iris flower prediction service (FastAPI) that classifies species from sepal length."""

import math
import os
from typing import Any, Dict

import joblib
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Resolve every bundled asset relative to this file rather than the process's
# working directory, so the app starts no matter where it is launched from.
# (The model path already did this; the UI paths previously did not.)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UI_DIR = os.path.join(BASE_DIR, "ui")

app = FastAPI(
    title="Iris Flower Prediction API",
    description="A machine learning API for predicting iris flower species based on sepal length",
    version="1.0.0",
)

# CORS Configuration for Hugging Face Spaces
origins = [
    "*",  # Allow all origins for Hugging Face Spaces
]

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Nothing visible here reads
# cookies or auth headers, so allow_credentials could likely be False -- confirm
# no client depends on it before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model once at import time; startup fails fast if the file is missing.
# joblib.load unpickles the file, so this must only ever point at a trusted artifact.
model_path = os.path.join(BASE_DIR, "model", "model.joblib")
model = joblib.load(model_path)

# Serve static files
app.mount("/static", StaticFiles(directory=UI_DIR), name="static")

# Class index -> species name, used when the model emits integer labels.
SPECIES_BY_INDEX = {
    0: "setosa",
    1: "versicolor",
    2: "virginica",
}

# Confidence reported when the model cannot supply class probabilities.
DEFAULT_CONFIDENCE = 0.95

# Static descriptive text attached to each prediction; built once at import
# rather than on every request.
SPECIES_INFO: Dict[str, Dict[str, str]] = {
    "setosa": {
        "description": "Iris setosa is a species in the genus Iris, it is also in the subgenus Limniris and in the series Tripetalae.",
        "characteristics": "Small flowers, narrow petals, grows in wet areas",
        "color": "Usually blue to purple",
    },
    "versicolor": {
        "description": "Iris versicolor is a species of Iris native to North America, in the Eastern United States and Eastern Canada.",
        "characteristics": "Medium-sized flowers, wider petals, grows in wetlands",
        "color": "Blue to purple with yellow markings",
    },
    "virginica": {
        "description": "Iris virginica is a perennial plant species of the genus Iris, part of the subgenus Limniris and in the series Laevigatae.",
        "characteristics": "Large flowers, broad petals, grows in wet areas",
        "color": "Blue to purple, sometimes white",
    },
}


class IrisPredictionRequest(BaseModel):
    """Request body for a prediction: a single sepal length measurement."""

    sepal_length: float

    class Config:
        schema_extra = {
            "example": {
                "sepal_length": 5.1
            }
        }


class IrisPredictionResponse(BaseModel):
    """Prediction result with the species name, confidence and descriptive info."""

    prediction: str
    confidence: float
    sepal_length: float
    species_info: Dict[str, Any]


def _label_to_species(label: Any) -> str:
    """Map a raw label from ``model.predict`` to a species name or ``"unknown"``.

    Integer labels (including NumPy integer scalars, which hash and compare
    equal to Python ints) are looked up in ``SPECIES_BY_INDEX``. String labels,
    as produced by a model trained on species names, are accepted as-is when
    they name a known species.
    """
    if isinstance(label, str):  # np.str_ is a str subclass, so it lands here too
        return label if label in SPECIES_INFO else "unknown"
    return SPECIES_BY_INDEX.get(label, "unknown")


def _prediction_confidence(features) -> float:
    """Return the highest class probability for ``features``.

    Falls back to ``DEFAULT_CONFIDENCE`` when the model cannot produce
    probabilities.
    """
    try:
        probabilities = model.predict_proba(features)[0]
        return float(max(probabilities))
    except Exception:
        # Deliberate best effort: estimators without predict_proba (or whose
        # predict_proba is disabled) raise here, and the endpoint should still answer.
        return DEFAULT_CONFIDENCE


@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main interface"""
    return FileResponse(os.path.join(UI_DIR, "index.html"))


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/api/info")
async def get_api_info():
    """Get API information"""
    return {
        "name": "Iris Flower Prediction API",
        "version": "1.0.0",
        "description": "Predict iris flower species based on sepal length",
        "input_features": ["sepal_length"],
        "output_classes": list(SPECIES_BY_INDEX.values()),
        "model_type": "Machine Learning Classifier",
    }


@app.post("/api/predict", response_model=IrisPredictionResponse)
async def predict_iris(data: IrisPredictionRequest):
    """
    Predict iris flower species based on sepal length

    Args:
        data: IrisPredictionRequest containing sepal_length

    Returns:
        IrisPredictionResponse with prediction and additional info

    Raises:
        HTTPException: 400 if sepal_length is not a finite positive number;
            500 if the model call or response construction fails.
    """
    sepal_length = data.sepal_length

    # Validate before entering the try below so these client errors surface as
    # 400s instead of being caught and re-wrapped as 500s.
    if not math.isfinite(sepal_length):
        # NaN would otherwise slip past the "<= 0" check (nan <= 0 is False).
        raise HTTPException(status_code=400, detail="Sepal length must be a finite number")
    if sepal_length <= 0:
        raise HTTPException(status_code=400, detail="Sepal length must be positive")

    features = [[sepal_length]]
    try:
        prediction_value = model.predict(features)[0]
        predicted_species = _label_to_species(prediction_value)
        confidence = _prediction_confidence(features)

        return IrisPredictionResponse(
            prediction=predicted_species,
            confidence=round(confidence, 3),
            sepal_length=sepal_length,
            species_info=SPECIES_INFO.get(predicted_species, {}),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e


@app.post("/predict")
async def legacy_predict(data: IrisPredictionRequest):
    """Legacy endpoint for backward compatibility; returns only the species name."""
    result = await predict_iris(data)
    return {"prediction": result.prediction}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)