File size: 4,949 Bytes
1e3bf36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import joblib
import os
import uvicorn
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from typing import Dict, Any
import numpy as np

# Application object; title/description/version feed the auto-generated
# OpenAPI docs served at /docs and /redoc.
app = FastAPI(
    title="Iris Flower Prediction API",
    description="A machine learning API for predicting iris flower species based on sepal length",
    version="1.0.0"
)

# CORS Configuration for Hugging Face Spaces
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests; harmless for this public
# demo, but tighten the origin list if credentials are ever used.
origins = [
    "*",  # Allow all origins for Hugging Face Spaces
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the trained classifier once at startup. Load defensively: if the
# artifact is missing or unreadable, keep the process alive with model=None
# so /health can report "model_loaded": false instead of the whole app
# failing on import (prediction requests will then return a 500).
model_path = os.path.join(os.path.dirname(__file__), "model", "model.joblib")
try:
    model = joblib.load(model_path)
except (OSError, ValueError):
    model = None  # surfaced via /health's "model_loaded" flag

# Serve the bundled front-end assets under /static.
app.mount("/static", StaticFiles(directory="ui"), name="static")

class IrisPredictionRequest(BaseModel):
    """Request body for prediction endpoints: a single sepal length in cm."""

    # Sole model feature; validated as > 0 inside the predict handler.
    sepal_length: float

    class Config:
        # Example shown in the OpenAPI schema / Swagger UI.
        # NOTE(review): `schema_extra` is the pydantic v1 key; under pydantic
        # v2 this is `json_schema_extra` — confirm the installed version.
        schema_extra = {
            "example": {
                "sepal_length": 5.1
            }
        }

class IrisPredictionResponse(BaseModel):
    """Response body for /api/predict."""

    prediction: str  # species name: "setosa" | "versicolor" | "virginica" | "unknown"
    confidence: float  # max class probability (rounded), or a fixed fallback
    sepal_length: float  # echo of the input value
    species_info: Dict[str, Any]  # descriptive facts about the predicted species

@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main interface"""
    # NOTE(review): declares response_class=HTMLResponse but returns a
    # FileResponse — works at runtime (FileResponse wins), the declared
    # class only affects the OpenAPI docs.
    return FileResponse("ui/index.html")

@app.get("/health")
async def health_check():
    """Liveness probe: reports service status and whether the model loaded."""
    model_ready = model is not None
    return {"status": "healthy", "model_loaded": model_ready}

@app.get("/api/info")
async def get_api_info():
    """Describe the API: identity, expected input feature, and output classes."""
    info = {
        "name": "Iris Flower Prediction API",
        "version": "1.0.0",
        "description": "Predict iris flower species based on sepal length",
    }
    # Model I/O contract, mirrored by the prediction endpoint.
    info["input_features"] = ["sepal_length"]
    info["output_classes"] = ["setosa", "versicolor", "virginica"]
    info["model_type"] = "Machine Learning Classifier"
    return info

@app.post("/api/predict", response_model=IrisPredictionResponse)
async def predict_iris(data: IrisPredictionRequest):
    """
    Predict iris flower species based on sepal length.

    Args:
        data: IrisPredictionRequest containing sepal_length

    Returns:
        IrisPredictionResponse with prediction and additional info

    Raises:
        HTTPException: 400 for a non-positive sepal length,
                       500 if the model fails to produce a prediction.
    """
    # Validate before the generic handler: previously this raise sat inside
    # the try block, so the intended 400 was caught by `except Exception`
    # and re-wrapped as a 500.
    if data.sepal_length <= 0:
        raise HTTPException(status_code=400, detail="Sepal length must be positive")

    try:
        features = [[data.sepal_length]]
        prediction_value = model.predict(features)[0]

        # Map numeric class labels to species names.
        # NOTE(review): assumes the estimator emits integer labels 0..2 —
        # a model trained on string labels would fall through to "unknown";
        # verify against the training pipeline.
        species_mapping = {
            0: "setosa",
            1: "versicolor",
            2: "virginica"
        }
        predicted_species = species_mapping.get(prediction_value, "unknown")

        # Prefer real class probabilities. Not every estimator implements
        # predict_proba, so catch only the missing-attribute case instead of
        # a bare except that could hide genuine failures.
        try:
            probabilities = model.predict_proba(features)[0]
            # Cast: numpy scalar -> plain float for clean JSON serialization.
            confidence = float(max(probabilities))
        except AttributeError:
            confidence = 0.95  # Default confidence if probabilities not available

        # Species information
        species_info = {
            "setosa": {
                "description": "Iris setosa is a species in the genus Iris, it is also in the subgenus Limniris and in the series Tripetalae.",
                "characteristics": "Small flowers, narrow petals, grows in wet areas",
                "color": "Usually blue to purple"
            },
            "versicolor": {
                "description": "Iris versicolor is a species of Iris native to North America, in the Eastern United States and Eastern Canada.",
                "characteristics": "Medium-sized flowers, wider petals, grows in wetlands",
                "color": "Blue to purple with yellow markings"
            },
            "virginica": {
                "description": "Iris virginica is a perennial plant species of the genus Iris, part of the subgenus Limniris and in the series Laevigatae.",
                "characteristics": "Large flowers, broad petals, grows in wet areas",
                "color": "Blue to purple, sometimes white"
            }
        }

        return IrisPredictionResponse(
            prediction=predicted_species,
            confidence=round(confidence, 3),
            sepal_length=data.sepal_length,
            species_info=species_info.get(predicted_species, {})
        )

    except HTTPException:
        # Preserve deliberate HTTP errors instead of converting them to 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")

@app.post("/predict")
async def legacy_predict(data: IrisPredictionRequest):
    """Legacy endpoint for backward compatibility: returns only the label."""
    full_response = await predict_iris(data)
    return {"prediction": full_response.prediction}

# Run a local server when executed directly; port 7860 is the
# Hugging Face Spaces convention.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)