# LUCIAN / src/model.py
# github-actions[bot]
# Sync from GitHub (89cd667b817f5a97f6ec0ff6ce33fd08bc24475b)
# abb4211
import io
import logging
import os
import time
import numpy as np
import streamlit as st
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
from src.config import HF_FILENAME, HF_REPO_ID, IMAGE_SIZE, MODEL_PATH
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Upper bound on an uploaded file's size, in megabytes. preprocess_image
# computes the size as ``len(bytes) / 1e6``, i.e. decimal megabytes.
_MAX_FILE_SIZE_MB = 10
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the ConvNeXt model, downloading from HuggingFace if not cached locally.

    The function is decorated with ``st.cache_resource``, so Streamlit is
    expected to reuse the returned tuple instead of re-running the body on
    every call; the reported load time then describes the first, uncached
    load only.

    Returns
    -------
    tuple
        ``(model, elapsed)`` where ``model`` is whatever
        ``tf.keras.models.load_model(MODEL_PATH)`` returns and ``elapsed`` is
        the load duration in seconds (download time is not included).
    """
    # A zero-byte file is what a previously aborted download leaves behind,
    # so it is treated exactly like a missing file.
    if not os.path.exists(MODEL_PATH) or os.path.getsize(MODEL_PATH) == 0:
        logger.info("Model not found locally — downloading from HuggingFace Hub.")
        _download_model()

    logger.info("Loading model from %s", MODEL_PATH)
    # perf_counter() is monotonic; time.time() can jump (NTP sync, manual
    # clock changes) and yield a wrong or even negative duration.
    start = time.perf_counter()
    model = tf.keras.models.load_model(MODEL_PATH)
    elapsed = time.perf_counter() - start
    logger.info("Model loaded in %.2fs", elapsed)
    return model, elapsed
def _download_model() -> None:
    """Download the model from HuggingFace Hub using the official SDK.

    ``hf_hub_download`` is called with ``local_dir`` set to the directory of
    ``MODEL_PATH`` so the file is placed in the project folder rather than
    the default HuggingFace cache. The file ends up at
    ``<local_dir>/<HF_FILENAME>``, which only coincides with ``MODEL_PATH``
    when ``HF_FILENAME`` matches the final path component(s) of
    ``MODEL_PATH``; a warning is logged when that is not the case so the
    mismatch is visible before the subsequent load fails.
    """
    # dirname() is "" when MODEL_PATH is a bare filename, and os.makedirs("")
    # raises FileNotFoundError, so fall back to the current directory.
    target_dir = os.path.dirname(MODEL_PATH) or os.curdir
    os.makedirs(target_dir, exist_ok=True)

    logger.info("Downloading %s from repo %s", HF_FILENAME, HF_REPO_ID)
    saved_path = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
        local_dir=target_dir,
    )
    # Report where the SDK actually wrote the file instead of assuming it is
    # MODEL_PATH.
    logger.info("Model saved to %s", saved_path)

    if not os.path.exists(MODEL_PATH):
        logger.warning(
            "Downloaded file is at %s but MODEL_PATH (%s) does not exist; "
            "check that HF_FILENAME matches the file name in MODEL_PATH.",
            saved_path,
            MODEL_PATH,
        )
def preprocess_image(uploaded_file):
    """Load and preprocess an uploaded image for model inference.

    Parameters
    ----------
    uploaded_file
        Binary file-like object exposing ``read()`` (presumably the object
        returned by Streamlit's file uploader — confirm against the caller).
        If it is seekable it is rewound first so the whole file is read.

    Returns
    -------
    tuple
        ``(img, arr)``: ``img`` is the PIL image converted to RGB and resized
        to ``IMAGE_SIZE``; ``arr`` is a float32 NumPy array of the same
        pixels scaled to ``[0, 1]`` with a leading batch axis of length 1.

    Raises
    ------
    ValueError : if the file exceeds the size limit or is not a valid image.
    """
    # A stream that was already consumed would otherwise read as empty and be
    # rejected as "not a valid image"; rewind when the object supports it.
    seekable = getattr(uploaded_file, "seekable", None)
    if callable(seekable) and seekable():
        uploaded_file.seek(0)

    raw = uploaded_file.read()
    size_mb = len(raw) / 1e6
    if size_mb > _MAX_FILE_SIZE_MB:
        raise ValueError(
            f"File is too large ({size_mb:.1f} MB). Maximum allowed size is {_MAX_FILE_SIZE_MB} MB."
        )

    try:
        # Image.open() only parses the header; pixel data is decoded lazily
        # inside convert(), so truncated/corrupt files fail there with a
        # plain OSError rather than UnidentifiedImageError.
        img = Image.open(io.BytesIO(raw)).convert("RGB").resize(IMAGE_SIZE)
    except (UnidentifiedImageError, OSError, Image.DecompressionBombError) as exc:
        # Keep the documented contract: every undecodable upload surfaces as
        # ValueError, including images with an excessive pixel count.
        raise ValueError("The uploaded file is not a valid image.") from exc

    # Scale to [0, 1] floats and add a batch axis -> shape (1, height, width, 3),
    # where height/width come from IMAGE_SIZE (PIL sizes are (width, height)).
    arr = np.expand_dims(np.array(img) / 255.0, axis=0).astype(np.float32)
    return img, arr
def predict(model, img_array):
    """Run inference and return class probabilities with elapsed time.

    Parameters
    ----------
    model
        Callable model invoked as ``model(tensor, training=False)`` whose
        result exposes ``.numpy()``.
    img_array
        Batched input convertible to a float32 tensor; only the first item
        of the batch is returned (assumes a batch of one — confirm if larger
        batches are ever passed).

    Returns
    -------
    tuple
        ``(probs, elapsed)``: the model output for the first batch item as a
        NumPy array, and the inference duration in seconds. Whether the
        values are normalized probabilities depends on the model's final
        layer.
    """
    # perf_counter() is monotonic and high-resolution, so the duration cannot
    # be skewed by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    img_tensor = tf.convert_to_tensor(img_array, dtype=tf.float32)
    probs = model(img_tensor, training=False).numpy()[0]
    return probs, time.perf_counter() - start