#!/usr/bin/env python3
"""Streamlit app: Object Dimension & Volume Estimation.

Pipeline: the user uploads an image → Mask R-CNN detects object
instances → the largest high-confidence instance is cropped via its
binary mask → a ResNet-50 regression head estimates physical
dimensions (length/width/height in metres, volume in m³), displayed
converted to cm / cm³.
"""

import os

# ────────────────────────────────────────────────────────────
# ENV / CACHE SETUP (Hugging Face safe)
# ────────────────────────────────────────────────────────────
# Hugging Face Spaces containers have a non-writable default HOME, so
# every cache/config directory is redirected under /tmp. This MUST run
# before streamlit / matplotlib are imported, hence the unconventional
# placement ahead of the main import block.
os.environ["STREAMLIT_BROWSER_GATHER_USAGE_STATS"] = "false"
os.environ["STREAMLIT_DISABLE_WARNINGS"] = "true"
os.environ["HOME"] = "/tmp"

streamlit_config = os.path.join(os.environ["HOME"], ".streamlit")
os.makedirs(streamlit_config, exist_ok=True)
os.environ["STREAMLIT_CONFIG_DIR"] = streamlit_config
os.environ["MPLCONFIGDIR"] = os.path.join(os.environ["HOME"], ".matplotlib")
os.environ["XDG_CACHE_HOME"] = os.path.join(os.environ["HOME"], ".cache")

# ────────────────────────────────────────────────────────────
# IMPORTS
# ────────────────────────────────────────────────────────────
import io

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torchvision import models as torchvision_models
from torchvision.models.detection import maskrcnn_resnet50_fpn

# ────────────────────────────────────────────────────────────
# GLOBAL CONFIG
# ────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "model/pix3d_dimension_estimator_mask_crop.pth"
CNN_INPUT_SIZE = 224
# Standard ImageNet normalization — must match the regressor's training.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DETECTION_SCORE_THRESH = 0.5


# ────────────────────────────────────────────────────────────
# DIMENSION ESTIMATOR CNN
# ────────────────────────────────────────────────────────────
def create_dimension_estimator_cnn_for_inference(num_outputs=4):
    """Build the dimension-regression network architecture.

    ResNet-50 backbone (no pretrained weights — they are supplied by the
    checkpoint) with the classifier replaced by a small MLP head.

    Args:
        num_outputs: size of the regression output; 4 = (L, W, H, V).

    Returns:
        An un-initialized ``nn.Module`` ready for ``load_state_dict``.
    """
    base = torchvision_models.resnet50(weights=None)
    num_ftrs = base.fc.in_features
    base.fc = nn.Sequential(
        nn.Linear(num_ftrs, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_outputs),
    )
    return base


@st.cache_resource
def load_dimension_model():
    """Load the dimension-regression checkpoint onto DEVICE (cached).

    Returns the eval-mode model, or None (with a UI error) when the
    checkpoint is missing or fails to load.
    """
    if not os.path.exists(MODEL_PATH):
        st.error(f"Dimension model not found at {MODEL_PATH}")
        return None
    try:
        model = create_dimension_estimator_cnn_for_inference()
        # NOTE(review): consider torch.load(..., weights_only=True) on
        # torch >= 2.0 — the checkpoint is only a state_dict, and this
        # avoids arbitrary-code pickle deserialization.
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        print("Dimension model loaded.")
        return model
    except Exception as e:
        st.error(f"Failed to load dimension model: {e}")
        return None


# ────────────────────────────────────────────────────────────
# OBJECT DETECTION MODEL (Mask R-CNN)
# ────────────────────────────────────────────────────────────
@st.cache_resource
def load_detection_model():
    """Load a COCO-pretrained Mask R-CNN onto DEVICE (cached).

    Returns the eval-mode detector, or None (with a UI error) on failure.
    """
    try:
        # ``pretrained=True`` is deprecated in recent torchvision in
        # favor of ``weights=...`` but is kept for compatibility with
        # the torchvision version this app was built against.
        model = maskrcnn_resnet50_fpn(pretrained=True)
        model.to(DEVICE)
        model.eval()
        print("Mask R-CNN loaded.")
        return model
    except Exception as e:
        st.error(f"Error loading Mask R-CNN: {e}")
        return None


# ────────────────────────────────────────────────────────────
# UTILS
# ────────────────────────────────────────────────────────────
def get_largest_instance_index_from_masks(masks, scores, score_thresh=0.5):
    """Pick the detected instance with the largest mask area.

    Args:
        masks: detector mask tensor, shape (N, 1, H, W), soft values.
        scores: per-instance confidence tensor, shape (N,).
        score_thresh: instances below this confidence are ignored.

    Returns:
        Index into ``masks``/``scores`` of the largest qualifying
        instance, or -1 when nothing qualifies.
    """
    if masks is None or len(masks) == 0:
        return -1
    scores = scores.cpu().numpy()
    masks_np = masks.cpu().numpy()
    valid_indices = [i for i, s in enumerate(scores) if s >= score_thresh]
    if not valid_indices:
        return -1
    # Binarize at 0.5 and count foreground pixels per candidate.
    areas = [(masks_np[i, 0] > 0.5).sum() for i in valid_indices]
    return valid_indices[int(np.argmax(areas))]


def crop_from_binary_mask(img_rgb, mask_np, pad=5):
    """Crop the tight bounding box of a binary mask, with padding.

    Args:
        img_rgb: HxWx3 numpy image.
        mask_np: HxW binary mask (nonzero = foreground).
        pad: pixels of context added on each side (clamped to image).

    Returns:
        The cropped region, or None if the mask is empty / degenerate.
    """
    if mask_np.sum() == 0:
        return None
    mask_np = mask_np.astype(np.uint8)
    rows = np.any(mask_np, axis=1)
    cols = np.any(mask_np, axis=0)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]
    ymin = max(0, ymin - pad)
    xmin = max(0, xmin - pad)
    ymax = min(img_rgb.shape[0] - 1, ymax + pad)
    xmax = min(img_rgb.shape[1] - 1, xmax + pad)
    if ymin >= ymax or xmin >= xmax:
        return None
    return img_rgb[ymin:ymax + 1, xmin:xmax + 1]


def draw_detections(img_rgb, boxes, scores, thresh=0.5):
    """Return a copy of the image with green boxes over confident detections.

    Args:
        img_rgb: HxWx3 numpy image (left unmodified).
        boxes: tensor of (x1, y1, x2, y2) boxes.
        scores: per-box confidence tensor.
        thresh: boxes below this confidence are not drawn.
    """
    vis = img_rgb.copy()
    boxes = boxes.cpu().numpy()
    scores = scores.cpu().numpy()
    for box, score in zip(boxes, scores):
        if score < thresh:
            continue
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
    return vis


def predict_dimensions_cnn(img_rgb, model):
    """Run the dimension regressor on a cropped object image.

    The crop is resized to CNN_INPUT_SIZE and ImageNet-normalized. The
    model's 4 outputs are interpreted as metres (L, W, H) and cubic
    metres (V) and formatted in cm / cm³.

    Returns:
        A dict of formatted dimension strings, or an ``{"Error": ...}``
        dict on any prediction failure.
    """
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize((CNN_INPUT_SIZE, CNN_INPUT_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    try:
        inp = transform(img_rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out = model(inp).squeeze().cpu().tolist()
        # Pad short outputs with zeros; slice ignores any extras, so a
        # checkpoint with a different head size degrades gracefully
        # instead of raising an unpack error.
        while len(out) < 4:
            out.append(0.0)
        length_m, width_m, height_m, volume_m3 = out[:4]
        return {
            "Length (cm)": f"{length_m*100:.1f}",
            "Width (cm)": f"{width_m*100:.1f}",
            "Height (cm)": f"{height_m*100:.1f}",
            "Volume (cm³)": f"{volume_m3*1e6:.1f}",
        }
    except Exception:
        # Narrow enough to let KeyboardInterrupt/SystemExit propagate;
        # the UI surfaces failure via the returned dict.
        return {"Error": "CNN Prediction Failed"}


# ────────────────────────────────────────────────────────────
# STREAMLIT UI
# ────────────────────────────────────────────────────────────
st.set_page_config(layout="wide", page_title="Object Dimension Estimator")
st.title("Object Dimension & Volume Estimation")

dim_model = load_dimension_model()
det_model = load_detection_model()

st.subheader("Upload an image")
uploaded = st.file_uploader("Choose JPG/PNG", type=["jpg", "jpeg", "png"])

if uploaded:
    try:
        # getvalue() returns the full buffer regardless of read position,
        # which is more robust than read() after Streamlit reruns.
        raw_bytes = uploaded.getvalue()
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        img_np = np.array(img)
        st.image(img, caption="Uploaded Image", use_column_width=True)

        if dim_model and det_model:
            with st.spinner("Detecting & Estimating..."):
                img_tensor = T.ToTensor()(img).to(DEVICE)
                outputs = det_model([img_tensor])[0]
                boxes = outputs["boxes"]
                scores = outputs["scores"]
                masks = outputs.get("masks")

                if len(boxes) == 0:
                    st.warning("No objects detected.")
                else:
                    det_vis = draw_detections(
                        img_np, boxes, scores, DETECTION_SCORE_THRESH
                    )
                    st.image(
                        det_vis,
                        caption="Detected Objects",
                        use_column_width=True,
                    )
                    idx = get_largest_instance_index_from_masks(
                        masks, scores, DETECTION_SCORE_THRESH
                    )
                    if idx >= 0:
                        mask_np = (
                            masks[idx, 0].cpu().numpy() > 0.5
                        ).astype(np.uint8)
                        crop = crop_from_binary_mask(img_np, mask_np)
                        if crop is not None:
                            st.image(crop, caption="Cropped Object", width=250)
                            dims = predict_dimensions_cnn(crop, dim_model)
                            st.json(dims)
                        else:
                            st.error("Cropping failed.")
                    else:
                        st.warning("No high-confidence instance found.")
    except Exception as e:
        st.error(f"Error processing image: {e}")

# Sidebar device info
st.sidebar.markdown("---")
st.sidebar.write(f"Device: {DEVICE}")
st.sidebar.write(f"Detector: {'OK' if det_model else 'Failed'}")
st.sidebar.write(f"Dimension Model: {'OK' if dim_model else 'Failed'}")