# HW3.1 / app.py
# Uploaded by rlogh ("Upload 4 files", commit 515831f).
import gradio as gr
import pandas as pd
import numpy as np
from huggingface_hub import hf_hub_download
import os
import tempfile
import shutil
import zipfile
import pathlib
# Settings: where the zipped AutoGluon predictor lives on the Hugging Face Hub,
# and where it is cached / unpacked on local disk.
# NOTE(review): "autolguon" in the repo id looks like a typo but is presumably
# the repository's real name — verify on the Hub before changing it.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"
CACHE_DIR = pathlib.Path("hf_assets")  # download target for the zip
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")  # unpacked predictor lives here
# Download & load the native predictor
def _prepare_predictor_dir():
    """Download and extract the AutoGluon predictor directory.

    Fetches ``ZIP_FILENAME`` from the ``MODEL_REPO_ID`` model repo into
    ``CACHE_DIR``, wipes and recreates ``EXTRACT_DIR``, and unzips the archive
    there.

    Returns:
        The predictor directory as a ``str``. If the archive holds exactly one
        top-level folder (ignoring archive metadata such as ``__MACOSX`` and
        dotfiles), that folder is returned; otherwise ``EXTRACT_DIR`` itself.
        Returns ``None`` if any step fails; the error is printed, not raised.
    """
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        # `local_dir_use_symlinks` is deprecated (and ignored) by recent
        # huggingface_hub releases, so it is intentionally not passed.
        local_zip = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=ZIP_FILENAME,
            repo_type="model",
            local_dir=str(CACHE_DIR),
        )
        print(f"Downloaded ZIP file: {local_zip}")

        # Start from an empty extraction directory so stale files from a
        # previous run cannot mix with the freshly downloaded predictor.
        if EXTRACT_DIR.exists():
            shutil.rmtree(EXTRACT_DIR)
        EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(local_zip, "r") as zf:
            zf.extractall(str(EXTRACT_DIR))

        # Zips created by e.g. macOS Finder add a "__MACOSX" folder (and
        # hidden files) next to the real content; skip those so a single
        # wrapped predictor folder is still detected.
        contents = [
            entry
            for entry in EXTRACT_DIR.iterdir()
            if entry.name != "__MACOSX" and not entry.name.startswith(".")
        ]
        if len(contents) == 1 and contents[0].is_dir():
            predictor_root = contents[0]
        else:
            predictor_root = EXTRACT_DIR
        print(f"Extracted predictor to: {predictor_root}")
        return str(predictor_root)
    except Exception as e:
        # Best-effort: callers treat None as "no predictor available".
        print(f"Error preparing predictor directory: {e}")
        return None
def load_model():
    """Return the AutoGluon ``TabularPredictor`` for the flowers model, or None.

    Each failure path (AutoGluon not importable, predictor directory could not
    be prepared, predictor failed to load, or any other unexpected exception)
    is reported via ``print`` and results in ``None`` rather than an exception.
    """
    try:
        print("Attempting to load flowers model from Hugging Face...")

        # AutoGluon is an optional, heavy dependency: import it lazily.
        try:
            import autogluon.tabular as ag
            print("AutoGluon imported successfully")
        except ImportError as import_error:
            print(f"AutoGluon not available: {import_error}")
            print("Please install AutoGluon: pip install autogluon.tabular")
            return None

        predictor_path = _prepare_predictor_dir()
        if not predictor_path:
            print("Failed to prepare predictor directory")
            return None

        # The predictor may have been saved under a different Python version,
        # hence require_py_version_match=False.
        try:
            loaded = ag.TabularPredictor.load(predictor_path, require_py_version_match=False)
            print("Successfully loaded AutoGluon TabularPredictor")
        except Exception as load_error:
            print(f"Failed to load TabularPredictor: {load_error}")
            return None
        return loaded
    except Exception as unexpected:
        print(f"Error loading model: {unexpected}")
        print("Using demo mode - predictions will be based on heuristics")
        return None
# Load the model once at import time (simplified for Hugging Face Spaces).
# Any failure leaves `model` as None, which the app treats as demo mode.
try:
    model = load_model()
    load_status = (
        "Model loaded successfully"
        if model is not None
        else "Model loading failed, using demo mode"
    )
    print(load_status)
except Exception as load_exc:
    print(f"Model loading error: {load_exc}")
    model = None
def validate_inputs(flower_diameter, petal_length, petal_width, petal_count, stem_height):
    """Validate input parameters and return error messages.

    The bounds match the slider ranges used by the UI.

    Returns:
        A list of human-readable error strings, empty when every value is
        within its inclusive range.

    Raises:
        TypeError: if a value cannot be compared with a number (e.g. None).
    """
    errors = []
    if not (1 <= flower_diameter <= 10):
        errors.append("Flower diameter must be between 1 and 10 cm")
    if not (0.5 <= petal_length <= 5):
        errors.append("Petal length must be between 0.5 and 5 cm")
    if not (0.2 <= petal_width <= 2):
        errors.append("Petal width must be between 0.2 and 2 cm")
    if not (3 <= petal_count <= 12):
        errors.append("Petal count must be between 3 and 12")
    if not (10 <= stem_height <= 100):
        errors.append("Stem height must be between 10 and 100 cm")
    return errors


def _heuristic_color_scores(flower_diameter, petal_length, petal_width, petal_count, stem_height):
    """Score each candidate color with simple hand-written rules (higher = more likely)."""
    # Insertion order matters: max() breaks ties in favor of the earliest key.
    scores = dict.fromkeys(("Red", "Orange", "Yellow", "Blue", "Purple"), 0)

    # Larger flowers tend to be red/orange
    if flower_diameter > 5:
        scores["Red"] += 2
        scores["Orange"] += 1
    elif flower_diameter > 3:
        scores["Yellow"] += 1
        scores["Blue"] += 1
    else:
        scores["Purple"] += 1

    # Longer petals tend to be red/blue
    if petal_length > 3:
        scores["Red"] += 1
        scores["Blue"] += 1
    elif petal_length > 2:
        scores["Orange"] += 1
        scores["Yellow"] += 1

    # More petals tend to be yellow/blue
    if petal_count > 8:
        scores["Yellow"] += 2
        scores["Blue"] += 1
    elif petal_count > 6:
        scores["Orange"] += 1
        scores["Purple"] += 1

    # Taller stems tend to be blue/purple
    if stem_height > 60:
        scores["Blue"] += 2
        scores["Purple"] += 1
    elif stem_height > 40:
        scores["Red"] += 1
        scores["Orange"] += 1

    # Wider petals tend to be red/orange
    if petal_width > 1.2:
        scores["Red"] += 1
        scores["Orange"] += 1
    elif petal_width > 0.8:
        scores["Yellow"] += 1
        scores["Blue"] += 1

    return scores


def predict_flower_color(flower_diameter, petal_length, petal_width, petal_count, stem_height, confidence_threshold, prediction_method):
    """
    Predict flower color based on physical characteristics.

    Args:
        flower_diameter: Flower diameter in centimeters (1-10)
        petal_length: Petal length in centimeters (0.5-5)
        petal_width: Petal width in centimeters (0.2-2)
        petal_count: Number of petals (3-12)
        stem_height: Stem height in centimeters (10-100)
        confidence_threshold: Minimum confidence threshold (0-100)
        prediction_method: Requested method ("Heuristic" or "Model"). Only the
            rule-based heuristic is implemented here, so any non-"Heuristic"
            request is reported as a heuristic result rather than mislabeled
            as coming from the trained model.

    Returns:
        A multi-line result string with the predicted color, a pseudo-confidence
        and the method used. Validation problems and unexpected exceptions are
        also returned as strings, never raised.
    """
    try:
        validation_errors = validate_inputs(flower_diameter, petal_length, petal_width, petal_count, stem_height)
        if validation_errors:
            return "Input Validation Errors:\n" + "\n".join(f"• {error}" for error in validation_errors)

        colors = _heuristic_color_scores(flower_diameter, petal_length, petal_width, petal_count, stem_height)
        result = max(colors, key=colors.get)
        # Pseudo-confidence: 60% base + 5% per winning point, capped at 95%.
        confidence = min(95, 60 + max(colors.values()) * 5)

        # Apply confidence threshold
        if confidence < confidence_threshold:
            return f"Low Confidence Prediction\nPrediction: {result}\nConfidence: {confidence:.1f}%\n(Below threshold of {confidence_threshold}%)"

        # The trained model is never consulted in this function, so report
        # the method that actually produced the answer.
        if prediction_method == "Heuristic":
            method_text = "Heuristic"
        else:
            method_text = "Heuristic (trained model not used)"
        return f"Prediction: {result}\nConfidence: {confidence:.1f}%\nMethod: {method_text}\nThreshold: {confidence_threshold}%"
    except Exception as e:
        return f"Error: {str(e)}"
# Define the Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the flower color predictor.

    Layout: a row with an input column (feature sliders, inference
    parameters, Predict button) and a result column (output box, model
    notes), followed by clickable example inputs.

    Returns:
        The constructed ``gr.Blocks`` app; launching it is left to the caller.
    """
    with gr.Blocks(title="Flower Color Predictor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌸 Flower Color Predictor
        This model predicts the **color** of a flower based on its physical characteristics.
        The model was trained using AutoGluon Tabular on a dataset of 330 flowers (30 original + 300 synthetic).
        ## How to use:
        1. Adjust the sliders for each flower characteristic
        2. Click "Predict" to get the color classification
        3. Try the example inputs below to see how it works!
        """)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Features")
                flower_diameter = gr.Slider(
                    minimum=1, maximum=10, value=5, step=0.1,
                    label="Flower Diameter (cm)",
                    info="Diameter of the flower in centimeters (1-10 cm)"
                )
                petal_length = gr.Slider(
                    minimum=0.5, maximum=5, value=2.5, step=0.1,
                    label="Petal Length (cm)",
                    info="Length of the petals in centimeters (0.5-5 cm)"
                )
                petal_width = gr.Slider(
                    minimum=0.2, maximum=2, value=1, step=0.1,
                    label="Petal Width (cm)",
                    info="Width of the petals in centimeters (0.2-2 cm)"
                )
                petal_count = gr.Slider(
                    minimum=3, maximum=12, value=6, step=1,
                    label="Petal Count",
                    info="Number of petals on the flower (3-12)"
                )
                stem_height = gr.Slider(
                    minimum=10, maximum=100, value=50, step=1,
                    label="Stem Height (cm)",
                    info="Height of the stem in centimeters (10-100 cm)"
                )
                gr.Markdown("### Inference Parameters")
                confidence_threshold = gr.Slider(
                    minimum=0, maximum=100, value=60, step=5,
                    label="Confidence Threshold (%)",
                    info="Minimum confidence required for prediction (0-100%)"
                )
                prediction_method = gr.Radio(
                    choices=["Heuristic", "Model"], value="Heuristic",
                    label="Prediction Method",
                    info="Choose between heuristic rules or trained model"
                )
                predict_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
            with gr.Column():
                gr.Markdown("### Prediction Result")
                output = gr.Textbox(
                    label="Result",
                    placeholder="Click 'Predict' to see the color prediction...",
                    lines=3
                )
                gr.Markdown("### Model Information")
                gr.Markdown("""
                **Model Type:** AutoGluon Tabular Predictor
                **Features:** 5 input features
                **Target:** Multi-class classification (Flower Colors)
                **Training Data:** 330 samples (30 original + 300 synthetic)
                **Accuracy:** 100% (may indicate overfitting)
                """)

        # The components passed to predict_flower_color, in parameter order.
        predict_inputs = [
            flower_diameter, petal_length, petal_width, petal_count,
            stem_height, confidence_threshold, prediction_method,
        ]

        # Example inputs
        gr.Markdown("### Example Inputs")
        # Hidden box that receives each example's trailing description.
        example_description = gr.Textbox(visible=False)

        def _predict_example(*example_values):
            # Example rows carry one extra value (the description) that
            # predict_flower_color does not accept; drop it so the call has
            # the right arity if Gradio ever runs or caches the examples.
            return predict_flower_color(*example_values[:len(predict_inputs)])

        gr.Examples(
            examples=[
                [5.7, 3.2, 0.9, 5, 21.2, 60, "Heuristic", "Large red flower"],
                [3.4, 1.3, 1.0, 7, 68.7, 70, "Heuristic", "Small blue flower"],
                [4.2, 2.1, 0.8, 8, 45.0, 50, "Heuristic", "Medium yellow flower"],
                [6.1, 2.8, 1.2, 6, 35.5, 80, "Heuristic", "Large orange flower"],
                [2.8, 1.8, 0.6, 9, 55.0, 65, "Heuristic", "Small purple flower"]
            ],
            inputs=predict_inputs + [example_description],
            outputs=output,
            fn=_predict_example,
            cache_examples=False
        )

        # Event handlers
        predict_btn.click(
            fn=predict_flower_color,
            inputs=predict_inputs,
            outputs=output
        )
    return demo
# Entry point: build the UI and serve it on all interfaces at port 7860,
# without creating a public share link.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo = create_interface()
    demo.launch(**launch_options)