# NOTE(review): removed non-code scrape residue (file-size banner, git short
# hashes, and a column of bare line numbers) that preceded the actual source.

import gradio as gr
from PIL import Image
import pathlib
import os
import shutil
import zipfile
import huggingface_hub
from huggingface_hub import hf_hub_download
from autogluon.multimodal import MultiModalPredictor
import pandas as pd

# Model config
# Hugging Face Hub model repo that hosts the zipped AutoGluon predictor.
MODEL_REPO_ID = "cassieli226/sign-identification-automl"
# Name of the archive file inside that repo containing the saved predictor dir.
ZIP_FILENAME = "autogluon_predictor_dir.zip"
# Local cache directory for assets downloaded from the Hub.
CACHE_DIR = pathlib.Path("hf_assets")
# Directory the predictor archive is extracted into.
EXTRACT_DIR = CACHE_DIR / "predictor_native"

# Download & load the native predictor
def _prepare_predictor_dir() -> str:
    """Download the zipped AutoGluon predictor from the Hub and extract it.

    Returns:
        Path (as ``str``) to the extracted predictor directory. If the
        archive wraps everything in a single top-level folder, that folder
        is returned; otherwise the extraction root itself is returned.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # NOTE: the deprecated `local_dir_use_symlinks` kwarg was dropped —
    # recent huggingface_hub releases warn and ignore it; `local_dir`
    # alone already materializes the real file under CACHE_DIR.
    local_zip = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )
    # Re-extract from scratch so stale files from a previous model version
    # cannot linger next to the new ones.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Archives are often built around one root folder; unwrap it if so.
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)

# Download/extract and load the predictor once at import time so the Gradio
# callbacks below can reuse a single in-memory model across requests.
PREDICTOR_DIR = _prepare_predictor_dir()
PREDICTOR = MultiModalPredictor.load(PREDICTOR_DIR)

# Short alias used by the inference callback.
clf = PREDICTOR  # MultiModalPredictor

def predict_image(image, confidence_threshold=0.5, return_probabilities=False):
    """Classify an uploaded image with the loaded AutoGluon predictor.

    Args:
        image: PIL image from the Gradio widget, or ``None`` when nothing
            has been uploaded.
        confidence_threshold: Minimum probability a class must reach to be
            listed in the detailed breakdown (display-only filter).
        return_probabilities: When True, append the top-5 class
            probabilities to the detailed markdown output.

    Returns:
        Tuple of ``(prediction string, detailed markdown)``. On failure the
        error is reported in both elements instead of raising, so the UI
        never crashes.
    """
    if image is None:
        return "Please upload an image", ""

    try:
        # MultiModalPredictor expects tabular input: one row whose "image"
        # column holds the PIL image.
        input_df = pd.DataFrame([{"image": image}])

        # str() guards against numpy scalar labels rendering oddly in the UI.
        prediction = str(clf.predict(input_df)[0])
        detailed = f"**Predicted Class:** {prediction}"

        if return_probabilities:
            try:
                probs = clf.predict_proba(input_df)
                prob_dict = probs.iloc[0].to_dict()
                sorted_probs = sorted(prob_dict.items(), key=lambda x: x[1], reverse=True)
                detailed += "\n\n**Prediction Probabilities:**\n"
                shown = 0
                for class_name, prob in sorted_probs[:5]:
                    if prob >= confidence_threshold:
                        detailed += f"- {class_name}: {prob:.3f}\n"
                        shown += 1
                if shown == 0:
                    # Don't leave an empty section when the threshold
                    # filters out every class.
                    detailed += f"- (no class probability >= {confidence_threshold:.2f})\n"
            except Exception as e:
                # Best-effort: probabilities are supplementary; keep the
                # primary prediction even if predict_proba fails.
                detailed += f"\n\nNote: Could not retrieve probabilities ({str(e)})"

        return prediction, detailed

    except Exception as e:
        # UI boundary: surface the error to the user rather than crash.
        return f"Error: {str(e)}", f"**Error:** {str(e)}"

# --- Gradio UI ---
# Declarative layout: image + controls on the left, results on the right.
with gr.Blocks(title="Image Classifier", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🖼️ AI Image Classifier

    ## Upload an image to get AI-powered classification results
    """)

    with gr.Row():
        # Left column: input image and inference parameters.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image", height=300)
            gr.Markdown("### 🔧 Inference Parameters")
            # NOTE(review): the slider default (0.1) differs from
            # predict_image's signature default (0.5); the slider value is
            # what actually gets passed — confirm this is intended.
            confidence_threshold = gr.Slider(0.0, 1.0, value=0.1, step=0.05, label="Confidence Threshold")
            return_probabilities = gr.Checkbox(label="Show Prediction Probabilities", value=True)
            predict_btn = gr.Button("🎯 Classify Image", variant="primary", size="lg")

        # Right column: short result plus detailed markdown breakdown.
        with gr.Column():
            prediction_output = gr.Textbox(label="Prediction Result", interactive=False, lines=1)
            detailed_output = gr.Markdown(label="Detailed Analysis", value="Results will appear here...")

    # Wire the button to the inference callback defined above.
    predict_btn.click(
        fn=predict_image,
        inputs=[image_input, confidence_threshold, return_probabilities],
        outputs=[prediction_output, detailed_output]
    )

    gr.Markdown("""
    ---
    *Built with AutoGluon and Gradio*
    """)

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()