dimension_model / src /streamlit_app.py
suryaprakash01's picture
Update src/streamlit_app.py
8f4c685 verified
raw
history blame
3.1 kB
import streamlit as st
import torch
import torchvision.transforms as T
import numpy as np
import cv2
from PIL import Image, UnidentifiedImageError
import os
# Detectron2
try:
import detectron2
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer, ColorMode
from detectron2.data import MetadataCatalog
except ImportError:
st.error("Detectron2 not found. Please ensure it's in requirements.txt and Hugging Face supports it.")
# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CNN_INPUT_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
MODEL_PATH = "model/pix3d_dimension_estimator_mask_crop.pth"
# Dimension Estimation CNN
from torchvision import models as torchvision_models
import torch.nn as nn
def create_dimension_estimator_cnn_for_inference(num_outputs=4):
model = torchvision_models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
nn.Linear(num_ftrs, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, num_outputs)
)
return model
@st.cache_resource
def load_dimension_model():
if not os.path.exists(MODEL_PATH):
st.error(f"Dimension model not found at {MODEL_PATH}. Add your .pth file to the model/ folder.")
return None
model = create_dimension_estimator_cnn_for_inference()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
return model
@st.cache_resource
def load_detectron2_model():
try:
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.DEVICE = "cpu" if DEVICE.type == "cpu" else "cuda"
predictor = DefaultPredictor(cfg)
metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0] if cfg.DATASETS.TRAIN else "coco_2017_val")
return predictor, metadata
except Exception as e:
st.error(f"Detectron2 load error: {e}")
return None, None
# Helper functions (get_largest_instance_index, crop_from_mask, predict_dimensions_cnn) go here
# For brevity, they are omitted in this snippet. Copy them from your existing code.
st.set_page_config(layout="wide", page_title="Object Dimension Estimator")
st.title("Object Dimension & Volume Estimation")
dim_model = load_dimension_model()
d2_predictor, d2_metadata = load_detectron2_model()
uploaded_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
if uploaded_file:
image = Image.open(uploaded_file).convert("RGB")
image_np = np.array(image)
outputs = d2_predictor(cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
instances = outputs["instances"].to("cpu")
# Process and display results as in your original code