|
|
import os |
|
|
import shutil |
|
|
import zipfile |
|
|
import pathlib |
|
|
import tempfile |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
from PIL import Image |
|
|
|
|
|
import huggingface_hub |
|
|
|
|
|
|
|
|
# AutoGluon is an optional dependency: record whether it imported so the
# rest of the module can degrade gracefully instead of failing at import.
try:
    import autogluon.multimodal
except ImportError:
    AUTOGLUON_AVAILABLE = False
    print("AutoGluon not available, using demo mode")
else:
    AUTOGLUON_AVAILABLE = True

# Hugging Face Hub location of the zipped predictor directory.
MODEL_REPO_ID = "its-zion-18/sign-image-autogluon-predictor"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
# Optional access token read from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local working directories (relative to the current working directory):
# the download cache and the folder the archive is unpacked into.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")
|
|
|
|
|
|
|
|
def _prepare_predictor_dir():
    """Fetch the zipped predictor from the Hub and unpack it locally.

    The archive named by ``ZIP_FILENAME`` is downloaded from
    ``MODEL_REPO_ID`` into ``CACHE_DIR`` and extracted into a freshly
    recreated ``EXTRACT_DIR``.

    Returns:
        The predictor directory as a string: the single top-level folder
        of the archive when it contains exactly one directory, otherwise
        ``EXTRACT_DIR`` itself. ``None`` if any step fails (the error is
        printed rather than raised).
    """
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)

        download_options = {
            "repo_id": MODEL_REPO_ID,
            "filename": ZIP_FILENAME,
            "repo_type": "model",
            "token": HF_TOKEN,
            "local_dir": str(CACHE_DIR),
            "local_dir_use_symlinks": False,
        }
        archive_path = huggingface_hub.hf_hub_download(**download_options)

        # Always extract into an empty directory so files left over from a
        # previous extraction cannot mix with the new ones.
        if EXTRACT_DIR.exists():
            shutil.rmtree(EXTRACT_DIR)
        EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(archive_path, "r") as archive:
            archive.extractall(str(EXTRACT_DIR))

        # If the archive wraps everything in one folder, that folder is the
        # predictor root; otherwise the extraction directory is.
        entries = list(EXTRACT_DIR.iterdir())
        if len(entries) == 1 and entries[0].is_dir():
            return str(entries[0])
        return str(EXTRACT_DIR)
    except Exception as err:
        print(f"Error preparing predictor directory: {err}")
        return None
|
|
|
|
|
|
|
|
# Maps the model's integer class ids to the text shown to the user.
CLASS_LABELS = {
    0: "Not a Stop Sign",
    1: "Stop Sign",
}

# Lazy-loading state: nothing is loaded at import time. PREDICTOR holds the
# loaded model (or None); PREDICTOR_LOADED records that a load attempt has
# already been made.
PREDICTOR_LOADED = False
PREDICTOR = None

print("Starting app in fast mode...")
|
|
|
|
|
def get_human_label(prediction):
    """Convert a raw model prediction to a human-readable label.

    Args:
        prediction: Raw predicted value; anything ``int()`` accepts
            (ints, floats, numeric strings, ...).

    Returns:
        The matching entry of ``CLASS_LABELS``; ``"Unknown Class (...)"``
        when the value converts to an integer that is not a known class
        id; ``"Invalid Prediction (...)"`` when the value cannot be
        converted to an integer at all.
    """
    try:
        pred_value = int(prediction)
    except (ValueError, TypeError, OverflowError):
        # ValueError covers NaN and non-numeric strings, TypeError covers
        # None and other unsupported types, and OverflowError covers
        # infinite floats (int(float("inf")) raises it).
        return f"Invalid Prediction ({prediction})"
    return CLASS_LABELS.get(pred_value, f"Unknown Class ({prediction})")
|
|
|
|
|
def load_model_lazy():
    """Return the shared predictor, loading it on the first call.

    Loading is deferred until a prediction is requested so that app
    startup stays fast. Only one load attempt is ever made: its outcome
    (a predictor or ``None``) is stored in the module-level ``PREDICTOR``
    and returned by every later call.

    Returns:
        The loaded ``MultiModalPredictor``, or ``None`` when AutoGluon is
        not installed, the model directory could not be prepared, or
        loading raised an error.
    """
    global PREDICTOR, PREDICTOR_LOADED

    # A previous call already attempted the load; reuse its outcome.
    if PREDICTOR_LOADED:
        return PREDICTOR

    if not AUTOGLUON_AVAILABLE:
        print("AutoGluon not available - cannot load model")
        PREDICTOR_LOADED = True
        return None

    try:
        print("Loading AutoGluon model from Hugging Face...")
        predictor_path = _prepare_predictor_dir()
        if not predictor_path:
            PREDICTOR = None
            print("❌ Could not prepare model directory")
        else:
            print(f"Loading predictor from: {predictor_path}")
            PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(predictor_path)
            print("✅ Model loaded successfully!")
            print(f"Model type: {type(PREDICTOR)}")
    except Exception as load_error:
        print(f"❌ Error loading model: {load_error}")
        PREDICTOR = None
        print("Model loading failed - predictions will not be available")

    # Mark the attempt as done whether or not it succeeded.
    PREDICTOR_LOADED = True
    return PREDICTOR
|
|
|
|
|
def predict_sign(image, confidence_threshold, preprocessing_option):
    """Classify a sign image and build the three outputs shown in the UI.

    Args:
        image: Image to classify (the UI passes a PIL image), or ``None``
            when nothing has been uploaded.
        confidence_threshold: Percentage in ``[0, 100]``. Predictions whose
            confidence is below it are reported as low confidence. Values
            that are non-numeric, NaN or out of range fall back to 70.
        preprocessing_option: When truthy, the 224x224 preview is returned
            as the third output; otherwise that slot is ``None``.

    Returns:
        A ``(message, original_image, preview_image)`` tuple. ``message``
        is the formatted prediction or an error description. Both image
        slots are ``None`` when no image was supplied or an unexpected
        error occurred before the images were prepared.
    """
    try:
        if image is None:
            return "No image uploaded", None, None

        # The chained comparison also rejects NaN, which would otherwise
        # slip past separate "< 0" / "> 100" checks.
        if not isinstance(confidence_threshold, (int, float)) or not (0 <= confidence_threshold <= 100):
            confidence_threshold = 70

        if image.mode != 'RGB':
            image = image.convert('RGB')

        original_image = image.copy()
        # Display-only preview: the model itself is given the full-size RGB
        # image saved to disk below, not this resized copy.
        preprocessed_display = image.resize((224, 224))

        def outputs(message):
            """Attach the image outputs, hiding the preview unless requested."""
            preview = preprocessed_display if preprocessing_option else None
            return message, original_image, preview

        model = load_model_lazy()
        if model is None:
            print("Model not available, cannot make prediction")
            return outputs("❌ Model Error\nUnable to load the trained model.\nPlease try again or contact support.\nMethod: Error")

        print("Using AutoGluon model for prediction...")
        try:
            # The predictor is fed a DataFrame of image paths, so the
            # image is written to a temporary file first.
            tmpdir = pathlib.Path(tempfile.mkdtemp())
            try:
                img_path = tmpdir / "input.png"
                image.save(img_path)

                df = pd.DataFrame({"image": [str(img_path)]})
                print(f"Created DataFrame with image path: {img_path}")

                print("Getting prediction from model...")
                prediction = model.predict(df)
                raw_prediction = prediction.iloc[0]
                predicted_class = get_human_label(raw_prediction)
                print(f"Raw prediction: {raw_prediction}")
                print(f"Human label: {predicted_class}")

                print("Getting prediction probabilities...")
                proba_df = model.predict_proba(df)
                confidence = float(proba_df.iloc[0].max()) * 100
                print(f"Confidence: {confidence:.1f}%")

                # Diagnostic only: log the probability of the predicted class.
                pred_value = int(raw_prediction)
                if pred_value in proba_df.columns:
                    class_confidence = float(proba_df.iloc[0][pred_value]) * 100
                    print(f"Class {pred_value} confidence: {class_confidence:.1f}%")
            finally:
                # Remove the temp directory on failure as well as success so
                # failed predictions do not leak files. Cleanup is best
                # effort and must never mask the prediction outcome.
                try:
                    shutil.rmtree(tmpdir)
                except Exception as cleanup_error:
                    print(f"Warning: Could not clean up temp directory: {cleanup_error}")

            # Fall back to a neutral value if the probability is unusable
            # (out of range or NaN).
            if not (0 <= confidence <= 100):
                confidence = 50.0

            if confidence < confidence_threshold:
                result = f"⚠️ Low Confidence Prediction\nPrediction: {predicted_class}\nConfidence: {confidence:.1f}%\n(Threshold: {confidence_threshold}%)\nRaw Output: {raw_prediction}\nMethod: AutoGluon Model"
            else:
                result = f"✅ Prediction: {predicted_class}\nConfidence: {confidence:.1f}%\nRaw Output: {raw_prediction}\nMethod: AutoGluon Model"
            return outputs(result)

        except Exception as e:
            print(f"❌ Error with model prediction: {e}")
            return outputs(f"❌ Prediction Error\nModel prediction failed: {str(e)}\nMethod: Error")

    except Exception as e:
        # Last-resort boundary so the UI always receives a message instead
        # of an unhandled exception.
        return f"Error: {str(e)}", None, None
|
|
|
|
|
|
|
|
# Gradio UI: inputs in the left column, prediction text and image tabs in the
# right column, followed by clickable examples and the event wiring.
with gr.Blocks(title="Sign Image Classifier") as demo:
    gr.Markdown("# Sign Image Classifier")
    gr.Markdown("Upload an image containing a sign to classify it.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Sign Image (PNG or JPG)", sources=["upload", "webcam"])
            confidence_threshold = gr.Slider(minimum=0, maximum=100, value=70, step=5, label="Confidence Threshold (%)")
            preprocessing_option = gr.Checkbox(value=True, label="Show Preprocessing")
            classify_btn = gr.Button("Classify Sign", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Prediction Result",
                value="Upload an image and click 'Classify Sign' to see the prediction...",
                lines=6,
                interactive=False,
            )

            with gr.Tabs():
                with gr.Tab("Original Image"):
                    original_display = gr.Image(label="Original Image", type="pil", interactive=False)
                with gr.Tab("Preprocessed Image"):
                    preprocessed_display = gr.Image(label="Preprocessed Image (Model Input)", type="pil", interactive=False)

    gr.Markdown("### Example Images")
    # Each row is: image URL, confidence threshold, show-preprocessing flag,
    # and a caption that is bound to a hidden textbox.
    example_rows = [
        ["https://www.myparkingsign.com/img/lg2/K/k2-4958-2.png", 70, True, "Stop Sign"],
        ["https://res.cloudinary.com/grimcoweb/image/upload/c_limit,f_auto,q_auto,w_500/v1608017423/Catalog/ProductImages/speedlimitsignproduct-image.jpg", 80, True, "Speed Limit Sign"],
        ["https://cdn11.bigcommerce.com/s-4nops3qe/images/stencil/1280x1280/products/14450/18972/street-signs__99115.1511199912.jpg?c=2", 60, True, "Street Sign"],
    ]
    try:
        example_caption = gr.Textbox(visible=False)
        gr.Examples(
            examples=example_rows,
            inputs=[image_input, confidence_threshold, preprocessing_option, example_caption],
            outputs=[output_text, original_display, preprocessed_display],
            fn=predict_sign,
            cache_examples=False,
            label="Try these example signs:",
        )
    except Exception as e:
        # The examples are optional; the app still works without them.
        print(f"Warning: Could not load examples: {e}")
        gr.Markdown("Example images temporarily unavailable.")

    # Run a prediction both on an explicit button click and whenever the
    # image changes; both events use the same inputs and outputs.
    for trigger in (classify_btn.click, image_input.change):
        trigger(
            fn=predict_sign,
            inputs=[image_input, confidence_threshold, preprocessing_option],
            outputs=[output_text, original_display, preprocessed_display],
        )

if __name__ == "__main__":
    demo.launch()