# Jesudian Challapalli
# Application files are added
# 1e3bf36
import math
import os
from typing import Any, Dict

import joblib
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
# The FastAPI application object. Its metadata (title, description, version)
# identifies the service in FastAPI's generated API documentation.
app = FastAPI(
    version="1.0.0",
    title="Iris Flower Prediction API",
    description="A machine learning API for predicting iris flower species based on sepal length",
)
# CORS configuration for Hugging Face Spaces: any origin, any HTTP method and
# any request header is accepted.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests;
# with allow_credentials=True, Starlette's CORSMiddleware reflects the caller's
# Origin instead (verify for the installed version), which lets any site make
# credentialed calls. The visible endpoints use no cookies or auth, so
# allow_credentials=False may be the safer setting -- confirm before changing.
origins = ["*"]  # Allow all origins for Hugging Face Spaces

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Resolve bundled assets relative to this module rather than the current
# working directory, so the app starts no matter where it is launched from.
# abspath() guards against a relative __file__, whose dirname would be "".
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Load the trained classifier once, at import time.
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code;
# only ever point it at a trusted, locally bundled model artifact.
model_path = os.path.join(BASE_DIR, "model", "model.joblib")
model = joblib.load(model_path)

# Serve the front-end assets in ui/ under the /static URL prefix.
ui_dir = os.path.join(BASE_DIR, "ui")
app.mount("/static", StaticFiles(directory=ui_dir), name="static")
class IrisPredictionRequest(BaseModel):
    """Request body for the prediction endpoints: a single sepal length."""

    # The only feature passed to the model. Positivity is not enforced here;
    # the /api/predict handler rejects non-positive values itself.
    sepal_length: float

    class Config:
        # Example payload attached to the generated JSON schema / API docs.
        # NOTE(review): `Config.schema_extra` is the Pydantic v1 spelling;
        # Pydantic v2 renamed it to `json_schema_extra` (set via model_config)
        # and warns on the old name -- confirm which version is installed.
        schema_extra = {
            "example": {
                "sepal_length": 5.1
            }
        }
class IrisPredictionResponse(BaseModel):
    """Response body returned by the /api/predict endpoint."""

    # Predicted species name, or "unknown" when the model's label is not recognised.
    prediction: str
    # Highest class probability rounded to 3 decimals, or a fixed default
    # when the model cannot provide probabilities.
    confidence: float
    # Echo of the sepal length that was submitted.
    sepal_length: float
    # Descriptive text (description / characteristics / color) for the
    # predicted species; empty when the species is unknown.
    species_info: Dict[str, Any]
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main interface (ui/index.html).

    The file is located relative to this module, as the model file is,
    instead of relative to the process's current working directory, so the
    page is found regardless of where the server was started.
    """
    index_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "ui", "index.html"
    )
    return FileResponse(index_path)
@app.get("/health")
async def health_check():
    """Report liveness and whether a model object is present.

    The model is loaded unconditionally when the module is imported, so if
    this handler is reachable the load did not raise; the flag is False only
    if the loaded object itself is None.
    """
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}
@app.get("/api/info")
async def get_api_info():
    """Return static metadata describing this API, its input and its classes."""
    info = dict(
        name="Iris Flower Prediction API",
        version="1.0.0",
        description="Predict iris flower species based on sepal length",
        input_features=["sepal_length"],
        output_classes=["setosa", "versicolor", "virginica"],
        model_type="Machine Learning Classifier",
    )
    return info
# Maps the model's numeric class labels to species names.
# NOTE(review): assumes the classifier was trained with labels 0/1/2 in this
# order (the scikit-learn load_iris convention) -- confirm against training.
_SPECIES_MAPPING = {
    0: "setosa",
    1: "versicolor",
    2: "virginica",
}

# Descriptive text returned with each prediction, keyed by species name.
# Built once at import time instead of on every request.
_SPECIES_INFO = {
    "setosa": {
        "description": "Iris setosa is a species in the genus Iris, it is also in the subgenus Limniris and in the series Tripetalae.",
        "characteristics": "Small flowers, narrow petals, grows in wet areas",
        "color": "Usually blue to purple",
    },
    "versicolor": {
        "description": "Iris versicolor is a species of Iris native to North America, in the Eastern United States and Eastern Canada.",
        "characteristics": "Medium-sized flowers, wider petals, grows in wetlands",
        "color": "Blue to purple with yellow markings",
    },
    "virginica": {
        "description": "Iris virginica is a perennial plant species of the genus Iris, part of the subgenus Limniris and in the series Laevigatae.",
        "characteristics": "Large flowers, broad petals, grows in wet areas",
        "color": "Blue to purple, sometimes white",
    },
}

# Confidence reported when the model cannot provide class probabilities.
_DEFAULT_CONFIDENCE = 0.95


def _species_from_label(label) -> str:
    """Translate one raw model output label into a species name, or "unknown"."""
    species = _SPECIES_MAPPING.get(label)
    if species is not None:
        return species
    # A model trained on string targets (e.g. "setosa" or "Iris-setosa")
    # returns names rather than integers; accept those when they name a
    # known species instead of reporting "unknown".
    if isinstance(label, str):
        name = label.strip().lower()
        if name.startswith("iris-"):
            name = name[len("iris-"):]
        if name in _SPECIES_INFO:
            return name
    return "unknown"


def _prediction_confidence(features) -> float:
    """Return the highest class probability the model assigns to *features*.

    Falls back to _DEFAULT_CONFIDENCE only when the model offers no usable
    predict_proba (missing/disabled method); any other error propagates.
    """
    try:
        probabilities = model.predict_proba(features)[0]
    except (AttributeError, NotImplementedError):
        return _DEFAULT_CONFIDENCE
    return float(max(probabilities))


@app.post("/api/predict", response_model=IrisPredictionResponse)
async def predict_iris(data: IrisPredictionRequest):
    """
    Predict iris flower species based on sepal length

    Args:
        data: IrisPredictionRequest containing sepal_length

    Returns:
        IrisPredictionResponse with prediction and additional info

    Raises:
        HTTPException: 400 if sepal_length is not a positive finite number;
            500 if the model fails while predicting.
    """
    sepal_length = data.sepal_length

    # Validate before entering the try block: previously these 400s were
    # caught by the generic handler below and re-reported as 500s.
    if sepal_length <= 0:
        raise HTTPException(status_code=400, detail="Sepal length must be positive")
    # NaN compares False against everything, so it slips past "<= 0";
    # +inf also passes. Neither is a usable model input.
    if not math.isfinite(sepal_length):
        raise HTTPException(status_code=400, detail="Sepal length must be a finite number")

    features = [[sepal_length]]
    try:
        predicted_species = _species_from_label(model.predict(features)[0])
        confidence = _prediction_confidence(features)
        return IrisPredictionResponse(
            prediction=predicted_species,
            confidence=round(confidence, 3),
            sepal_length=sepal_length,
            species_info=_SPECIES_INFO.get(predicted_species, {}),
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") from e
@app.post("/predict")
async def legacy_predict(data: IrisPredictionRequest):
    """Legacy endpoint for backward compatibility.

    Delegates to ``predict_iris`` (so it shares its validation and errors)
    and returns only the species name as ``{"prediction": <name>}``.
    """
    full_response = await predict_iris(data)
    species = full_response.prediction
    return {"prediction": species}
if __name__ == "__main__":
    # Run the app directly with uvicorn on every interface, port 7860
    # (presumably the Hugging Face Spaces default port -- confirm).
    host, port = "0.0.0.0", 7860
    uvicorn.run(app, host=host, port=port)