| import gradio as gr |
| import pandas as pd |
| import numpy as np |
| from huggingface_hub import hf_hub_download |
| import os |
| import tempfile |
| import shutil |
| import zipfile |
| import pathlib |
|
|
| |
# Hugging Face Hub location of the zipped AutoGluon predictor.
# NOTE(review): the repo id is used verbatim; "autolguon" looks like a typo but
# must match the actual Hub repo name — confirm on the Hub before changing it.
MODEL_REPO_ID = "its-zion-18/flowers-tabular-autolguon-predictor"
ZIP_FILENAME = "autogluon_predictor_dir.zip"

# Local working area: the archive is downloaded into CACHE_DIR and unpacked
# into EXTRACT_DIR.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")
|
|
| |
def _prepare_predictor_dir():
    """Download the zipped AutoGluon predictor from the Hub and extract it locally.

    The archive ``ZIP_FILENAME`` is fetched from the model repo ``MODEL_REPO_ID``
    into ``CACHE_DIR`` and unpacked into a freshly emptied ``EXTRACT_DIR``.

    Returns:
        The predictor root directory as a string, or ``None`` if any step fails
        (the error is printed, not raised).
    """
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        # ``local_dir_use_symlinks`` is deprecated in huggingface_hub and ignored
        # by recent versions, so it is no longer passed.
        local_zip = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=ZIP_FILENAME,
            repo_type="model",
            local_dir=str(CACHE_DIR),
        )
        print(f"Downloaded ZIP file: {local_zip}")

        # Start from an empty directory so files left over from a previous
        # extraction cannot mix with the predictor about to be loaded.
        if EXTRACT_DIR.exists():
            shutil.rmtree(EXTRACT_DIR)
        EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(local_zip, "r") as zf:
            zf.extractall(str(EXTRACT_DIR))

        # Archives created on macOS often carry a "__MACOSX" folder, and other
        # tools may add hidden dotfiles; ignore them so a single wrapper
        # directory is still recognised as the predictor root.
        contents = [
            entry
            for entry in EXTRACT_DIR.iterdir()
            if entry.name != "__MACOSX" and not entry.name.startswith(".")
        ]
        if len(contents) == 1 and contents[0].is_dir():
            predictor_root = contents[0]
        else:
            predictor_root = EXTRACT_DIR
        print(f"Extracted predictor to: {predictor_root}")
        return str(predictor_root)

    except Exception as e:
        # Best-effort: the caller treats None as "run without the model".
        print(f"Error preparing predictor directory: {e}")
        return None
|
|
def _import_autogluon():
    """Return the ``autogluon.tabular`` module, or ``None`` when it cannot be imported."""
    try:
        import autogluon.tabular as ag_tabular
        print("AutoGluon imported successfully")
        return ag_tabular
    except ImportError as import_err:
        print(f"AutoGluon not available: {import_err}")
        print("Please install AutoGluon: pip install autogluon.tabular")
        return None


def load_model():
    """Try to load the AutoGluon ``TabularPredictor`` used by this app.

    Returns:
        The loaded predictor, or ``None`` when AutoGluon is not installed, the
        predictor directory could not be prepared, or loading failed. All
        failures are reported with ``print`` instead of being raised.
    """
    try:
        print("Attempting to load flowers model from Hugging Face...")

        # AutoGluon is imported lazily so the app can still start without it.
        ag = _import_autogluon()
        if ag is None:
            return None

        model_dir = _prepare_predictor_dir()
        if not model_dir:
            print("Failed to prepare predictor directory")
            return None

        try:
            loaded = ag.TabularPredictor.load(model_dir, require_py_version_match=False)
            print("Successfully loaded AutoGluon TabularPredictor")
            return loaded
        except Exception as load_err:
            print(f"Failed to load TabularPredictor: {load_err}")
            return None

    except Exception as unexpected:
        print(f"Error loading model: {unexpected}")
        print("Using demo mode - predictions will be based on heuristics")
        return None
|
|
| |
# Load the predictor once at import time; on any failure fall back to
# heuristic-only "demo mode", represented by model = None.
try:
    model = load_model()
    print("Model loaded successfully" if model is not None else "Model loading failed, using demo mode")
except Exception as exc:
    print(f"Model loading error: {exc}")
    model = None
|
|
def validate_inputs(flower_diameter, petal_length, petal_width, petal_count, stem_height):
    """Check each flower measurement against its inclusive allowed range.

    Returns:
        A list of human-readable error messages, one per out-of-range value,
        in parameter order. An empty list means every input is valid.
    """
    checks = (
        (flower_diameter, 1, 10, "Flower diameter must be between 1 and 10 cm"),
        (petal_length, 0.5, 5, "Petal length must be between 0.5 and 5 cm"),
        (petal_width, 0.2, 2, "Petal width must be between 0.2 and 2 cm"),
        (petal_count, 3, 12, "Petal count must be between 3 and 12"),
        (stem_height, 10, 100, "Stem height must be between 10 and 100 cm"),
    )
    return [message for value, low, high, message in checks if not (low <= value <= high)]
|
|
def _heuristic_color_scores(flower_diameter, petal_length, petal_width, petal_count, stem_height):
    """Score each candidate color with the hand-written threshold rules.

    Returns:
        A dict mapping color name to integer score, in the order Red, Orange,
        Yellow, Blue, Purple. This order decides ties when the caller uses ``max``.
    """
    scores = {"Red": 0, "Orange": 0, "Yellow": 0, "Blue": 0, "Purple": 0}

    # Flower diameter
    if flower_diameter > 5:
        scores["Red"] += 2
        scores["Orange"] += 1
    elif flower_diameter > 3:
        scores["Yellow"] += 1
        scores["Blue"] += 1
    else:
        scores["Purple"] += 1

    # Petal length
    if petal_length > 3:
        scores["Red"] += 1
        scores["Blue"] += 1
    elif petal_length > 2:
        scores["Orange"] += 1
        scores["Yellow"] += 1

    # Petal count
    if petal_count > 8:
        scores["Yellow"] += 2
        scores["Blue"] += 1
    elif petal_count > 6:
        scores["Orange"] += 1
        scores["Purple"] += 1

    # Stem height
    if stem_height > 60:
        scores["Blue"] += 2
        scores["Purple"] += 1
    elif stem_height > 40:
        scores["Red"] += 1
        scores["Orange"] += 1

    # Petal width
    if petal_width > 1.2:
        scores["Red"] += 1
        scores["Orange"] += 1
    elif petal_width > 0.8:
        scores["Yellow"] += 1
        scores["Blue"] += 1

    return scores


def predict_flower_color(flower_diameter, petal_length, petal_width, petal_count, stem_height, confidence_threshold, prediction_method):
    """
    Predict flower color based on physical characteristics.

    Every prediction comes from the rule-based heuristic in
    ``_heuristic_color_scores``; the AutoGluon predictor is not consulted here.

    Args:
        flower_diameter: Flower diameter in centimeters (1-10)
        petal_length: Petal length in centimeters (0.5-5)
        petal_width: Petal width in centimeters (0.2-2)
        petal_count: Number of petals (3-12)
        stem_height: Stem height in centimeters (10-100)
        confidence_threshold: Minimum confidence threshold (0-100)
        prediction_method: "Heuristic" or "Model". "Model" is still answered
            by the heuristic, and the result says so.

    Returns:
        A multi-line result string. It holds one of: the validation errors, a
        low-confidence notice, the prediction with its confidence and method,
        or "Error: ..." if an exception was raised.
    """
    try:
        validation_errors = validate_inputs(flower_diameter, petal_length, petal_width, petal_count, stem_height)
        if validation_errors:
            return "Input Validation Errors:\n" + "\n".join(f"• {error}" for error in validation_errors)

        colors = _heuristic_color_scores(flower_diameter, petal_length, petal_width, petal_count, stem_height)

        result = max(colors, key=colors.get)
        # Heuristic pseudo-confidence: 60% base plus 5 points per winning score, capped at 95%.
        confidence = min(95, 60 + max(colors.values()) * 5)

        if confidence < confidence_threshold:
            return f"Low Confidence Prediction\nPrediction: {result}\nConfidence: {confidence:.1f}%\n(Below threshold of {confidence_threshold}%)"

        # The loaded AutoGluon predictor is never invoked in this function, so
        # a "Model" request must not be reported as "Method: Model".
        if prediction_method == "Heuristic":
            method_text = "Heuristic"
        else:
            method_text = "Heuristic (model inference not implemented)"

        return f"Prediction: {result}\nConfidence: {confidence:.1f}%\nMethod: {method_text}\nThreshold: {confidence_threshold}%"

    except Exception as e:
        return f"Error: {str(e)}"
|
|
| |
def create_interface():
    """Build the Gradio Blocks UI for the flower color predictor.

    Lays out the five measurement sliders, the inference parameters, the
    result box, a static model-information panel and clickable examples. It
    also wires the Predict button to ``predict_flower_color``.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Flower Color Predictor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
# 🌸 Flower Color Predictor

This model predicts the **color** of a flower based on its physical characteristics.
The model was trained using AutoGluon Tabular on a dataset of 330 flowers (30 original + 300 synthetic).

## How to use:
1. Adjust the sliders for each flower characteristic
2. Click "Predict" to get the color classification
3. Try the example inputs below to see how it works!
""")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input Features")

                flower_diameter = gr.Slider(
                    minimum=1, maximum=10, value=5, step=0.1,
                    label="Flower Diameter (cm)",
                    info="Diameter of the flower in centimeters (1-10 cm)"
                )

                petal_length = gr.Slider(
                    minimum=0.5, maximum=5, value=2.5, step=0.1,
                    label="Petal Length (cm)",
                    info="Length of the petals in centimeters (0.5-5 cm)"
                )

                petal_width = gr.Slider(
                    minimum=0.2, maximum=2, value=1, step=0.1,
                    label="Petal Width (cm)",
                    info="Width of the petals in centimeters (0.2-2 cm)"
                )

                petal_count = gr.Slider(
                    minimum=3, maximum=12, value=6, step=1,
                    label="Petal Count",
                    info="Number of petals on the flower (3-12)"
                )

                stem_height = gr.Slider(
                    minimum=10, maximum=100, value=50, step=1,
                    label="Stem Height (cm)",
                    info="Height of the stem in centimeters (10-100 cm)"
                )

                gr.Markdown("### Inference Parameters")

                confidence_threshold = gr.Slider(
                    minimum=0, maximum=100, value=60, step=5,
                    label="Confidence Threshold (%)",
                    info="Minimum confidence required for prediction (0-100%)"
                )

                prediction_method = gr.Radio(
                    choices=["Heuristic", "Model"], value="Heuristic",
                    label="Prediction Method",
                    info="Choose between heuristic rules or trained model"
                )

                predict_btn = gr.Button("🔍 Predict", variant="primary", size="lg")

            with gr.Column():
                gr.Markdown("### Prediction Result")
                output = gr.Textbox(
                    label="Result",
                    placeholder="Click 'Predict' to see the color prediction...",
                    lines=3
                )

                gr.Markdown("### Model Information")
                gr.Markdown("""
**Model Type:** AutoGluon Tabular Predictor
**Features:** 5 input features
**Target:** Multi-class classification (Flower Colors)
**Training Data:** 330 samples (30 original + 300 synthetic)
**Accuracy:** 100% (may indicate overfitting)
""")

        # Positional order must match predict_flower_color's parameters.
        predict_inputs = [
            flower_diameter,
            petal_length,
            petal_width,
            petal_count,
            stem_height,
            confidence_threshold,
            prediction_method,
        ]

        def _run_example(*values):
            """Run one example row through predict_flower_color, ignoring its description."""
            # Each example row ends with a description routed into a hidden
            # textbox; predict_flower_color accepts only the first seven values,
            # so passing the row as-is would raise a TypeError.
            return predict_flower_color(*values[:-1])

        gr.Markdown("### Example Inputs")
        gr.Examples(
            examples=[
                [5.7, 3.2, 0.9, 5, 21.2, 60, "Heuristic", "Large red flower"],
                [3.4, 1.3, 1.0, 7, 68.7, 70, "Heuristic", "Small blue flower"],
                [4.2, 2.1, 0.8, 8, 45.0, 50, "Heuristic", "Medium yellow flower"],
                [6.1, 2.8, 1.2, 6, 35.5, 80, "Heuristic", "Large orange flower"],
                [2.8, 1.8, 0.6, 9, 55.0, 65, "Heuristic", "Small purple flower"]
            ],
            inputs=predict_inputs + [gr.Textbox(visible=False)],
            outputs=output,
            fn=_run_example,
            cache_examples=False
        )

        predict_btn.click(
            fn=predict_flower_color,
            inputs=predict_inputs,
            outputs=output
        )

    return demo
|
|
| |
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860, without a public share link.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo = create_interface()
    demo.launch(**launch_options)
|
|