# aedupuga's picture
# Upload folder using huggingface_hub
# 9b2cbd5 verified
import random # Import random for selecting examples
import os # For reading environment variables
import shutil # For directory cleanup
import zipfile # For extracting model archives
import pathlib # For path manipulations
import tempfile # For creating temporary files/directories
import numpy as np # For image processing
# import pickle # No longer needed for loading from zip
import gradio # For interactive UI
import pandas # For tabular data handling
import PIL.Image # For image I/O
import huggingface_hub # For downloading model assets
import autogluon.multimodal # For loading AutoGluon image classifier
# Hub repository and archive holding the natively-saved AutoGluon predictor.
MODEL_REPO_ID = "jennifee/nnl_automl_model"  # model repo on the Hugging Face Hub
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"  # zipped predictor directory inside the repo
# Token from the environment so private repos work in a Space; None = anonymous access.
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Local directories: downloads land in CACHE_DIR, the zip is unpacked into EXTRACT_DIR.
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR / "predictor_native"  # Keep extract dir name for consistency
# Download & extract the native predictor
def _prepare_predictor_dir() -> str:
    """Download the zipped AutoGluon predictor from the Hub and extract it.

    Returns:
        Path (as str) of the directory to pass to MultiModalPredictor.load():
        the single top-level folder inside the zip when the archive wraps
        everything in one directory, otherwise the extraction root itself.
    """
    # NOTE(fix): the previous version called huggingface_hub.delete_repo(),
    # which deletes the *remote* model repository on the Hub — not a local
    # cache. To guarantee a fresh download we instead wipe the local cache
    # directory before re-downloading.
    if CACHE_DIR.exists():
        shutil.rmtree(CACHE_DIR)
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = huggingface_hub.hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        token=HF_TOKEN,
        local_dir=str(CACHE_DIR),
        # local_dir_use_symlinks is deprecated/ignored in recent hub versions
        # and has been dropped; local_dir alone yields real files here.
    )
    # Re-create a clean extraction directory so stale files never mix in.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    # Archives often wrap everything in a single top-level folder; unwrap it.
    contents = list(EXTRACT_DIR.iterdir())
    predictor_root = contents[0] if (len(contents) == 1 and contents[0].is_dir()) else EXTRACT_DIR
    return str(predictor_root)
# Prepare and load the native AutoGluon predictor once at import time.
# On any failure PREDICTOR stays None so do_predict() can surface the error
# instead of crashing the whole app.
try:
    PREDICTOR_DIR = _prepare_predictor_dir()
    PREDICTOR = autogluon.multimodal.MultiModalPredictor.load(PREDICTOR_DIR)
except Exception as e:
    print(f"Error loading predictor: {e}")
    PREDICTOR = None
# Explicit mapping of model class index -> human-readable display label.
# Edit this dict to change the names shown in the UI; unknown classes fall
# back to their raw label via _human_label().
CLASS_LABELS = {
    0: "👑 Face",
    1: "🔢 Value",
    # Removed suit labels and Joker as per new mapping
    # 5: "Unknown" # Keep Unknown as a placeholder if needed
}
# Helper to map a raw model class label to its display name
def _human_label(c):
    """Return the human-friendly name for class label *c*.

    Labels that can be coerced to int (e.g. the string "0") are normalized
    before lookup; anything not found in CLASS_LABELS falls back to str(c).
    """
    fallback = str(c)
    try:
        key = int(c)
    except (ValueError, TypeError):
        # Not integer-like — look the raw value up directly.
        key = c
    return CLASS_LABELS.get(key, fallback)
# Produce the "preprocessed" preview shown next to the original upload
def preprocess_image_for_display(pil_img: PIL.Image.Image):
    """Resize the image to 224x224 for display, or return None for no input.

    The predictor performs its real preprocessing internally; this resize
    only gives the user a consistent preview of what the model roughly sees.
    """
    return None if pil_img is None else pil_img.resize((224, 224))
# Do the prediction!
def do_predict(pil_img: PIL.Image.Image):
    """Classify an uploaded playing-card image.

    Returns a 3-tuple matching the Gradio outputs:
        (status message, {label: probability}, preprocessed display image).
    """
    # Guard clauses: nothing to do without an image and a loaded predictor.
    if pil_img is None:
        return "No image provided.", {}, None
    if PREDICTOR is None:
        return "Predictor not loaded. Please check the logs for errors.", {}, None
    # NOTE(fix): the previous pil_img.verify() call was removed — Pillow's
    # verify() is only valid on a freshly opened, not-yet-loaded image and
    # leaves the object unusable afterwards. Gradio already hands us a fully
    # decoded PIL image, so decoding success is the validity check.
    try:
        # AutoGluon expects a DataFrame of image file paths, so persist the
        # upload to a temp file. TemporaryDirectory cleans itself up (the
        # old mkdtemp() leaked one directory per request).
        with tempfile.TemporaryDirectory() as tmpdir:
            img_path = pathlib.Path(tmpdir) / "input.png"
            pil_img.save(img_path)
            df = pandas.DataFrame({"image": [str(img_path)]})
            # predict_proba returns a DataFrame: one row per input image,
            # one column per class.
            proba_output = PREDICTOR.predict_proba(df)
    except Exception as e:
        # Surface inference failures in the status box instead of crashing.
        return f"Prediction failed: {e}", {}, None
    if not proba_output.empty:
        # First (and only) row -> {human label: probability}.
        proba_series = proba_output.iloc[0]
        pretty_dict = {
            _human_label(k): float(v) for k, v in proba_series.to_dict().items()
        }
    else:
        pretty_dict = {}
    processed_img_display = preprocess_image_for_display(pil_img)
    return "Prediction Complete", pretty_dict, processed_img_display
# Representative example images shown as clickable thumbnails. Paths are
# relative to the app's working directory; each inner list maps to the single
# image input of the interface.
EXAMPLES = [
["./examples/WhatsApp Image 2025-09-12 at 22.05.40 (2).jpeg"],
["./examples/WhatsApp Image 2025-09-12 at 22.05.40 (3).jpeg"],
["./examples/WhatsApp Image 2025-09-12 at 22.05.40 (5).jpeg"]
]
# Gradio UI: two image panes (original / preprocessed), a probability label,
# and a status textbox, wired so predictions run on every new upload.
with gradio.Blocks() as demo:
    # Introduction
    gradio.Markdown("# Playing Card Detection")
    gradio.Markdown("""
    This is a simple app that demonstrates how to use an autogluon multimodal
    predictor in a gradio space to predict the type of playing card in a picture. To use,
    just upload a photo using the options below. The original and preprocessed
    images will be displayed, and the prediction results will appear automatically.
    """)
    with gradio.Row():
        # Incoming image and its preprocessed preview, side by side.
        image_in = gradio.Image(type="pil", label="Input image", sources=["upload", "webcam", "clipboard"])
        image_processed_out = gradio.Image(type="pil", label="Preprocessed image")
    # Prediction results: per-class probabilities plus a status line.
    proba_pretty = gradio.Label(num_top_classes=len(CLASS_LABELS), label="Class probabilities")
    prediction_status = gradio.Textbox(label="Prediction Status")
    # Run the prediction whenever a new image arrives. do_predict already has
    # the (image,) -> (status, probabilities, processed image) signature, so
    # the former `lambda img: do_predict(img)` wrapper was redundant.
    image_in.change(
        fn=do_predict,
        inputs=[image_in],
        outputs=[prediction_status, proba_pretty, image_processed_out]
    )
    # Clickable example images (shown only if any are configured).
    if EXAMPLES:
        gradio.Examples(
            examples=EXAMPLES,
            inputs=[image_in],
            label="Representative examples",
            examples_per_page=8,
            cache_examples=False,
        )
if __name__ == "__main__":
    demo.launch()