# dimension_model / app.py
# (Hugging Face Space file — header below retained as comments so the file parses)
# suryaprakash01's picture — Update app.py — b102658 verified
#!/usr/bin/env python3
import os
# ────────────────────────────────────────────────────────────
# ENV / CACHE SETUP (Hugging Face safe)
# ────────────────────────────────────────────────────────────
# Point HOME and every cache/config dir at /tmp so Streamlit and
# matplotlib can write in a container whose default HOME is read-only
# (e.g. a Hugging Face Space). Must run before `import streamlit`.
os.environ.update({
    "STREAMLIT_BROWSER_GATHER_USAGE_STATS": "false",
    "STREAMLIT_DISABLE_WARNINGS": "true",
    "HOME": "/tmp",
})
_streamlit_cfg_dir = os.path.join("/tmp", ".streamlit")
os.makedirs(_streamlit_cfg_dir, exist_ok=True)
os.environ["STREAMLIT_CONFIG_DIR"] = _streamlit_cfg_dir
os.environ["MPLCONFIGDIR"] = os.path.join("/tmp", ".matplotlib")
os.environ["XDG_CACHE_HOME"] = os.path.join("/tmp", ".cache")
# ────────────────────────────────────────────────────────────
# IMPORTS
# ────────────────────────────────────────────────────────────
import streamlit as st
import torch
import torchvision.transforms as T
from torchvision.models.detection import maskrcnn_resnet50_fpn
import numpy as np
import cv2
from PIL import Image
import io
import torch.nn as nn
from torchvision import models as torchvision_models
# ────────────────────────────────────────────────────────────
# GLOBAL CONFIG
# ────────────────────────────────────────────────────────────
# Run on GPU when available; models and tensors are moved here before inference.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint (state_dict) for the trained Pix3D dimension-regression CNN.
MODEL_PATH = "model/pix3d_dimension_estimator_mask_crop.pth"
# Side length the cropped object is resized to before the regression CNN.
CNN_INPUT_SIZE = 224
# Standard ImageNet normalization statistics (match the ResNet-50 backbone).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Minimum Mask R-CNN confidence for a detection to be drawn / used.
DETECTION_SCORE_THRESH = 0.5
# ────────────────────────────────────────────────────────────
# DIMENSION ESTIMATOR CNN
# ────────────────────────────────────────────────────────────
def create_dimension_estimator_cnn_for_inference(num_outputs=4):
    """Build the dimension-regression CNN used at inference time.

    A ResNet-50 backbone is created without pretrained weights (the real
    weights come from the checkpoint loaded later) and its final fully
    connected layer is swapped for a small regression head emitting
    ``num_outputs`` values — here length, width, height and volume.
    """
    backbone = torchvision_models.resnet50(weights=None)
    in_features = backbone.fc.in_features
    regression_head = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_outputs),
    )
    backbone.fc = regression_head
    return backbone
@st.cache_resource
def load_dimension_model():
    """Load the dimension-regression CNN from MODEL_PATH (cached per process).

    Returns the model in eval mode on DEVICE, or None — with a Streamlit
    error shown — when the checkpoint is missing or fails to load.
    """
    if not os.path.exists(MODEL_PATH):
        st.error(f"Dimension model not found at {MODEL_PATH}")
        return None
    try:
        net = create_dimension_estimator_cnn_for_inference()
        # NOTE(review): torch.load unpickles the file; the checkpoint is
        # assumed trusted (bundled with the app, not user-supplied).
        state = torch.load(MODEL_PATH, map_location=DEVICE)
        net.load_state_dict(state)
        net.to(DEVICE)
        net.eval()
    except Exception as e:
        st.error(f"Failed to load dimension model: {e}")
        return None
    print("Dimension model loaded.")
    return net
# ────────────────────────────────────────────────────────────
# OBJECT DETECTION MODEL (Mask R-CNN)
# ────────────────────────────────────────────────────────────
@st.cache_resource
def load_detection_model():
    """Load a COCO-pretrained Mask R-CNN detector (cached per process).

    Returns the model in eval mode on DEVICE, or None — with a Streamlit
    error shown — on failure.
    """
    try:
        # Use the modern `weights=` API, consistent with the `weights=None`
        # call in create_dimension_estimator_cnn_for_inference. The legacy
        # `pretrained=True` flag is deprecated (and removed in recent
        # torchvision); "DEFAULT" resolves to the same COCO weights.
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
        model.to(DEVICE)
        model.eval()
        print("Mask R-CNN loaded.")
        return model
    except Exception as e:
        st.error(f"Error loading Mask R-CNN: {e}")
        return None
# ────────────────────────────────────────────────────────────
# UTILS
# ────────────────────────────────────────────────────────────
def get_largest_instance_index_from_masks(masks, scores, score_thresh=0.5):
    """Return the index of the largest-area instance scoring >= score_thresh.

    `masks` is a (N, 1, H, W) tensor of soft masks and `scores` a length-N
    tensor of confidences. Area is the count of mask pixels above 0.5.
    Returns -1 when there are no masks or no instance clears the threshold.
    On ties the earliest qualifying index wins.
    """
    if masks is None or len(masks) == 0:
        return -1
    score_values = scores.cpu().numpy()
    mask_values = masks.cpu().numpy()
    best_idx = -1
    best_area = -1
    for i, score in enumerate(score_values):
        if score < score_thresh:
            continue
        area = int((mask_values[i, 0] > 0.5).sum())
        # Strict > keeps the first index on ties, matching argmax semantics.
        if area > best_area:
            best_area = area
            best_idx = i
    return best_idx
def crop_from_binary_mask(img_rgb, mask_np, pad=5):
    """Crop img_rgb to the bounding box of mask_np, padded by `pad` pixels.

    The box is clipped to the image bounds. Returns the cropped array, or
    None when the mask is empty or the clipped box is degenerate.
    """
    if mask_np.sum() == 0:
        return None
    ys, xs = np.nonzero(mask_np.astype(np.uint8))
    y0 = max(0, ys.min() - pad)
    x0 = max(0, xs.min() - pad)
    y1 = min(img_rgb.shape[0] - 1, ys.max() + pad)
    x1 = min(img_rgb.shape[1] - 1, xs.max() + pad)
    if y0 >= y1 or x0 >= x1:
        return None
    # +1 because y1/x1 are inclusive bounds.
    return img_rgb[y0:y1 + 1, x0:x1 + 1]
def draw_detections(img_rgb, boxes, scores, thresh=0.5):
    """Return a copy of img_rgb with green rectangles on detections >= thresh.

    `boxes` is an (N, 4) tensor of [x1, y1, x2, y2] and `scores` a length-N
    tensor of confidences. The input image is not modified.
    """
    canvas = img_rgb.copy()
    boxes_np = boxes.cpu().numpy()
    scores_np = scores.cpu().numpy()
    for box, score in zip(boxes_np, scores_np):
        if score < thresh:
            continue
        x1, y1, x2, y2 = box.astype(int)
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
    return canvas
def predict_dimensions_cnn(img_rgb, model):
    """Predict object dimensions from an RGB crop with the regression CNN.

    Parameters
    ----------
    img_rgb : HxWx3 uint8 RGB array (the cropped object).
    model : eval-mode regression CNN; assumed to output (L, W, H, V) in
        meters / cubic meters — TODO confirm units against training code.

    Returns
    -------
    dict mapping labels to formatted strings in cm / cm^3, or
    {"Error": "CNN Prediction Failed"} when preprocessing or inference fails.
    """
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize((CNN_INPUT_SIZE, CNN_INPUT_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    try:
        inp = transform(img_rgb).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            out = model(inp).squeeze().cpu().tolist()
        # Pad defensively in case the model emits fewer than 4 values.
        while len(out) < 4:
            out.append(0.0)
        L, W, H, V = out
        # Convert meters -> cm and m^3 -> cm^3 for display.
        return {
            "Length (cm)": f"{L*100:.1f}",
            "Width (cm)": f"{W*100:.1f}",
            "Height (cm)": f"{H*100:.1f}",
            "Volume (cm³)": f"{V*1e6:.1f}"  # fixed mojibake ("cmΒ³")
        }
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; any transform/model failure still yields
        # the same sentinel dict as before.
        return {"Error": "CNN Prediction Failed"}
# ────────────────────────────────────────────────────────────
# STREAMLIT UI
# ────────────────────────────────────────────────────────────
# Page chrome + model loading. Both loaders are @st.cache_resource, so they
# execute once per server process rather than on every Streamlit rerun.
st.set_page_config(layout="wide", page_title="Object Dimension Estimator")
st.title("Object Dimension & Volume Estimation")
dim_model = load_dimension_model()  # regression CNN, or None on failure
det_model = load_detection_model()  # Mask R-CNN detector, or None on failure
st.subheader("Upload an image")
uploaded = st.file_uploader("Choose JPG/PNG", type=["jpg", "jpeg", "png"])
if uploaded:
    try:
        # getvalue() returns the full upload as bytes without consuming the
        # underlying stream. NOTE(review): read() advances the buffer cursor,
        # so a Streamlit rerun could then see an empty payload — confirm.
        raw_bytes = uploaded.getvalue()
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        img_np = np.array(img)
        st.image(img, caption="Uploaded Image", use_column_width=True)
        if dim_model and det_model:
            with st.spinner("Detecting & Estimating..."):
                # Torchvision detection models take a list of CHW float
                # tensors; ToTensor also scales pixels to [0, 1].
                img_tensor = T.ToTensor()(img).to(DEVICE)
                outputs = det_model([img_tensor])[0]
                boxes = outputs["boxes"]
                scores = outputs["scores"]
                masks = outputs.get("masks")  # may be None for box-only models
                if len(boxes) == 0:
                    st.warning("No objects detected.")
                else:
                    # Show every confident detection, then run the dimension
                    # CNN on the single largest-area instance only.
                    det_vis = draw_detections(img_np, boxes, scores, DETECTION_SCORE_THRESH)
                    st.image(det_vis, caption="Detected Objects", use_column_width=True)
                    idx = get_largest_instance_index_from_masks(masks, scores, DETECTION_SCORE_THRESH)
                    if idx >= 0:
                        # Binarize the soft mask at 0.5, then crop a padded
                        # bounding box around it for the regression CNN.
                        mask_np = (masks[idx, 0].cpu().numpy() > 0.5).astype(np.uint8)
                        crop = crop_from_binary_mask(img_np, mask_np)
                        if crop is not None:
                            st.image(crop, caption="Cropped Object", width=250)
                            dims = predict_dimensions_cnn(crop, dim_model)
                            st.json(dims)
                        else:
                            st.error("Cropping failed.")
                    else:
                        st.warning("No high-confidence instance found.")
    except Exception as e:
        st.error(f"Error processing image: {e}")
# Sidebar diagnostics: active device and whether each model loaded.
st.sidebar.markdown("---")
st.sidebar.write(f"Device: {DEVICE}")
st.sidebar.write(f"Detector: {'OK' if det_model else 'Failed'}")
st.sidebar.write(f"Dimension Model: {'OK' if dim_model else 'Failed'}")