Spaces: Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import joblib | |
| import os | |
| import uvicorn | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, HTMLResponse | |
| from typing import Dict, Any | |
| import numpy as np | |
# FastAPI application object for the iris prediction service.
app = FastAPI(
    title="Iris Flower Prediction API",
    description=(
        "A machine learning API for predicting iris flower species "
        "based on sepal length"
    ),
    version="1.0.0",
)

# CORS Configuration for Hugging Face Spaces: the UI may be embedded or
# called from arbitrary origins, so every origin is allowed.
origins = [
    "*",  # Allow all origins for Hugging Face Spaces
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Resolve all filesystem paths relative to this file so the app behaves
# the same regardless of the current working directory. (Previously only
# the model path was anchored to __file__; the static "ui" directory was
# CWD-relative and broke when the server was launched from elsewhere.)
_BASE_DIR = os.path.dirname(__file__)

# Load the trained classifier once at import time.
model_path = os.path.join(_BASE_DIR, "model", "model.joblib")
model = joblib.load(model_path)

# Serve the bundled UI assets.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(_BASE_DIR, "ui")),
    name="static",
)
class IrisPredictionRequest(BaseModel):
    """Request payload: the single feature the model consumes."""

    # Sepal length measurement (the model's only input feature).
    sepal_length: float

    class Config:
        # Example rendered in the OpenAPI docs. NOTE(review): "schema_extra"
        # is the pydantic v1 spelling (v2 renamed it json_schema_extra) —
        # confirm the deployed pydantic version.
        schema_extra = {
            "example": {
                "sepal_length": 5.1
            }
        }
class IrisPredictionResponse(BaseModel):
    """Response payload returned by the prediction endpoint."""

    prediction: str               # predicted species name
    confidence: float             # probability of the predicted class
    sepal_length: float           # echo of the submitted input feature
    species_info: Dict[str, Any]  # descriptive metadata for the species
# NOTE(review): no route decorator was present on this handler in the
# reviewed source, leaving it unreachable; "/" is the conventional path
# for a landing page — confirm against the original deployment.
@app.get("/")
async def root():
    """Serve the main HTML interface bundled with the app."""
    # Anchor the path to this file so the page is found regardless of the
    # current working directory (previously it was CWD-relative).
    return FileResponse(os.path.join(os.path.dirname(__file__), "ui", "index.html"))
# NOTE(review): route decorator was absent in the reviewed source, making
# the handler unreachable; "/health" is assumed — confirm.
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the model object was loaded."""
    return {"status": "healthy", "model_loaded": model is not None}
# NOTE(review): route decorator was absent in the reviewed source, making
# the handler unreachable; "/api/info" is assumed — confirm.
@app.get("/api/info")
async def get_api_info():
    """Return static metadata describing the API's inputs and outputs."""
    return {
        "name": "Iris Flower Prediction API",
        "version": "1.0.0",
        "description": "Predict iris flower species based on sepal length",
        "input_features": ["sepal_length"],
        "output_classes": ["setosa", "versicolor", "virginica"],
        "model_type": "Machine Learning Classifier"
    }
# NOTE(review): route decorator was absent in the reviewed source, making
# the handler unreachable; "/predict" is assumed — confirm.
@app.post("/predict", response_model=IrisPredictionResponse)
async def predict_iris(data: IrisPredictionRequest):
    """
    Predict iris flower species based on sepal length.

    Args:
        data: IrisPredictionRequest containing sepal_length.

    Returns:
        IrisPredictionResponse with prediction and additional info.

    Raises:
        HTTPException: 400 for non-positive input, 500 on model failure.
    """
    try:
        # Validate input before touching the model.
        if data.sepal_length <= 0:
            raise HTTPException(status_code=400, detail="Sepal length must be positive")
        # Make prediction.
        features = [[data.sepal_length]]
        prediction_value = model.predict(features)[0]
        # Map the numeric class label to a species name.
        species_mapping = {
            0: "setosa",
            1: "versicolor",
            2: "virginica"
        }
        predicted_species = species_mapping.get(prediction_value, "unknown")
        # Use the classifier's probability estimate when it exposes one.
        # (Previously a bare ``except:`` swallowed *every* error here,
        # including SystemExit/KeyboardInterrupt.)
        predict_proba = getattr(model, "predict_proba", None)
        if predict_proba is not None:
            # float() strips the numpy scalar type for clean JSON output.
            confidence = float(max(predict_proba(features)[0]))
        else:
            confidence = 0.95  # Default confidence if probabilities not available
        # Descriptive metadata returned alongside the prediction.
        species_info = {
            "setosa": {
                "description": "Iris setosa is a species in the genus Iris, it is also in the subgenus Limniris and in the series Tripetalae.",
                "characteristics": "Small flowers, narrow petals, grows in wet areas",
                "color": "Usually blue to purple"
            },
            "versicolor": {
                "description": "Iris versicolor is a species of Iris native to North America, in the Eastern United States and Eastern Canada.",
                "characteristics": "Medium-sized flowers, wider petals, grows in wetlands",
                "color": "Blue to purple with yellow markings"
            },
            "virginica": {
                "description": "Iris virginica is a perennial plant species of the genus Iris, part of the subgenus Limniris and in the series Laevigatae.",
                "characteristics": "Large flowers, broad petals, grows in wet areas",
                "color": "Blue to purple, sometimes white"
            }
        }
        return IrisPredictionResponse(
            prediction=predicted_species,
            confidence=round(confidence, 3),
            sepal_length=data.sepal_length,
            species_info=species_info.get(predicted_species, {})
        )
    except HTTPException:
        # Bug fix: the 400 raised above was previously caught by the
        # generic handler below and re-raised as a 500, destroying the
        # validation error. Let HTTPExceptions propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
# NOTE(review): route decorator was absent in the reviewed source, making
# the handler unreachable; the path below is assumed — confirm against
# the clients that depend on this legacy shape.
@app.post("/predict_legacy")
async def legacy_predict(data: IrisPredictionRequest):
    """Legacy endpoint for backward compatibility: returns only the label."""
    result = await predict_iris(data)
    return {"prediction": result.prediction}
if __name__ == "__main__":
    # Hugging Face Spaces routes traffic to port 7860 by convention.
    uvicorn.run(app, host="0.0.0.0", port=7860)