Spaces:
Sleeping
Sleeping
| import io | |
| import logging | |
| import os | |
| import time | |
| import numpy as np | |
| import streamlit as st | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image, UnidentifiedImageError | |
| from src.config import HF_FILENAME, HF_REPO_ID, IMAGE_SIZE, MODEL_PATH | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Upload size limit in megabytes, enforced by preprocess_image
# (which counts one megabyte as 1e6 bytes).
_MAX_FILE_SIZE_MB = 10
def load_model():
    """Load the ConvNeXt model, downloading from HuggingFace if not cached locally.

    Returns
    -------
    tuple
        ``(model, elapsed)`` where ``model`` is whatever
        ``tf.keras.models.load_model(MODEL_PATH)`` returns and ``elapsed``
        is the time in seconds spent in that call. Any download time is
        not included in ``elapsed``.
    """
    # A zero-byte file is treated as missing: that is what a previously
    # aborted download leaves behind.
    if not os.path.exists(MODEL_PATH) or os.path.getsize(MODEL_PATH) == 0:
        logger.info("Model not found locally — downloading from HuggingFace Hub.")
        _download_model()

    # perf_counter() is monotonic, so the measured interval cannot be
    # skewed (or go negative) when the system clock is adjusted.
    start = time.perf_counter()
    logger.info("Loading model from %s", MODEL_PATH)
    model = tf.keras.models.load_model(MODEL_PATH)
    elapsed = time.perf_counter() - start
    logger.info("Model loaded in %.2fs", elapsed)
    return model, elapsed
def _download_model() -> None:
    """Download the model from HuggingFace Hub using the official SDK.

    hf_hub_download handles resumable downloads, caching, and progress
    reporting automatically — replacing the manual requests.get() approach.

    local_dir pins the file to the project folder instead of
    ~/.cache/huggingface, so the app finds it at MODEL_PATH on all platforms
    including Docker.

    The file lands at ``<local_dir>/<HF_FILENAME>``; if that is not
    MODEL_PATH a warning is logged, because the caller will then not find
    the model where it expects it.
    """
    # dirname() is "" when MODEL_PATH is a bare filename, and makedirs("")
    # raises FileNotFoundError — fall back to the current directory.
    target_dir = os.path.dirname(MODEL_PATH) or "."
    os.makedirs(target_dir, exist_ok=True)

    logger.info("Downloading %s from repo %s", HF_FILENAME, HF_REPO_ID)
    downloaded_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        local_dir=target_dir,
    )

    # Compare by path name (not realpath) so a symlinked download that sits
    # at MODEL_PATH still counts as being in the right place.
    if os.path.abspath(downloaded_path) == os.path.abspath(MODEL_PATH):
        logger.info("Model saved to %s", MODEL_PATH)
    else:
        logger.warning(
            "Model was downloaded to %s, which is not MODEL_PATH (%s); "
            "check that HF_FILENAME matches the MODEL_PATH file name.",
            downloaded_path,
            MODEL_PATH,
        )
def preprocess_image(uploaded_file):
    """Load and preprocess an uploaded image for model inference.

    Parameters
    ----------
    uploaded_file : file-like object whose ``read()`` returns the raw bytes
        of the upload.

    Returns
    -------
    tuple
        ``(img, arr)`` — the RGB PIL image resized to IMAGE_SIZE, and a
        float32 array of the same pixels scaled to [0, 1] with a leading
        batch axis.

    Raises
    ------
    ValueError : if the file exceeds the size limit or is not a valid image.
    """
    raw = uploaded_file.read()
    size_mb = len(raw) / 1e6
    if size_mb > _MAX_FILE_SIZE_MB:
        raise ValueError(
            f"File is too large ({size_mb:.1f} MB). Maximum allowed size is {_MAX_FILE_SIZE_MB} MB."
        )

    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB").resize(IMAGE_SIZE)
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        # Image.open() only reads the header; pixel data is decoded by
        # convert(), where a truncated or corrupt file raises a plain
        # OSError. Pillow also refuses images with an excessive pixel count
        # via DecompressionBombError. Both mean "not a usable image", so
        # they are reported through the documented ValueError.
        raise ValueError("The uploaded file is not a valid image.") from exc

    # Scale to [0, 1] floats and add a batch axis → shape (1, H, W, 3);
    # presumably (1, 224, 224, 3) — depends on IMAGE_SIZE, TODO confirm.
    arr = np.expand_dims(np.array(img) / 255.0, axis=0).astype(np.float32)
    return img, arr
def predict(model, img_array):
    """Run inference and return class probabilities with elapsed time.

    Parameters
    ----------
    model : callable accepting a tensor and a ``training`` keyword, whose
        result exposes ``.numpy()``.
    img_array : batched image data convertible to a float32 tensor.

    Returns
    -------
    tuple
        ``(probs, elapsed)`` — the first row of the model output as a NumPy
        array, and the time in seconds spent converting the input and
        running the model.
    """
    # perf_counter() is monotonic, so the measured interval is not affected
    # by system clock adjustments.
    start = time.perf_counter()
    img_tensor = tf.convert_to_tensor(img_array, dtype=tf.float32)
    # training=False is passed so the model runs in inference mode; [0]
    # selects the first item of the batch.
    probs = model(img_tensor, training=False).numpy()[0]
    return probs, time.perf_counter() - start