# NOTE: Hugging Face Space page chrome ("Mariam-Samy's picture / Update app.py /
# commit 89b15ef verified") was pasted above the code; kept as a comment so the
# file remains valid Python.
# app.py
import numpy as np
import cv2
from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
import gradio as gr
# ===============================
# Device
# ===============================
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================
# MNASNet Regressor
# ===============================
class MNASNetRegressor(nn.Module):
    """MNASNet-1.0 backbone with a small two-layer regression head.

    The stock ImageNet classifier is replaced by
    Linear(->128) -> ReLU -> Linear(->num_outputs) -> Tanh, so every output
    lands in [-1, 1] (denormalized back to millimeters elsewhere).
    """

    def __init__(self, num_outputs: int = 14, weights: str | None = "IMAGENET1K_V1"):
        super().__init__()
        # "NONE" (any case) or None means: random init, no pretrained weights.
        use_pretrained = weights is not None and not (
            isinstance(weights, str) and weights.upper() == "NONE"
        )
        if use_pretrained:
            weight_enum = getattr(models, "MNASNet1_0_Weights")[weights]
            backbone = models.mnasnet1_0(weights=weight_enum)
        else:
            backbone = models.mnasnet1_0(weights=None)
        in_features = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_outputs),
            nn.Tanh(),
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (batch, num_outputs) predictions in [-1, 1]."""
        return self.model(x)
# ===============================
# Normalization constants
# ===============================
# Training-set min/max of the metadata inputs; used for min-max normalization
# of the height/weight channels in build_input_tensor.
HEIGHT_MIN_CM = 152.3
HEIGHT_MAX_CM = 197.5
WEIGHT_MIN_KG = 46.1
WEIGHT_MAX_KG = 108.2
# Per-output min/max in MILLIMETERS used to denormalize the model's 14 tanh
# outputs. 14 outputs: last one is height in mm (as in the training setup).
Y_MIN_MM = np.array([
    148.5, 398.2, 204.3, 252.1, 692.7, 211.6, 711.2, 702.9,
    306.4, 703.2, 354.7, 602.5, 151.4, 1512.3
], dtype=np.float32)
Y_MAX_MM = np.array([
    250.1, 802.8, 405.9, 452.3, 1198.4, 403.1, 1102.0, 1199.7,
    502.5, 1203.9, 604.8, 999.6, 249.1, 2001.8
], dtype=np.float32)
# ImageNet channel statistics, applied to all three assembled input channels.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Labels for the 14 outputs, in the model's output order.
PREDICTION_NAMES = [
    "ankle","arm-length","bicep","calf","chest",
    "forearm","hip","leg-length","shoulder-breadth",
    "shoulder-to-crotch","thigh","waist","wrist","height"
]
# Ground-truth inputs cover ALL 14 measurements (including height), in cm;
# order must stay aligned with PREDICTION_NAMES (predict() zips positionally).
GT_NAMES = [
    "ankle","arm-length","bicep","calf","chest",
    "forearm","hip","leg-length","shoulder-breadth",
    "shoulder-to-crotch","thigh","waist","wrist","height"
]
# Working resolution (H, W) of ONE silhouette; front + side are concatenated
# horizontally (640x960) before being fed to the network.
SINGLE_H, SINGLE_W = 640, 480
# ===============================
# Checkpoint path (optional)
# ===============================
# If you have checkpoint.py, it should contain:
# BEST_CHECKPOINT_PATH = "/content/drive/MyDrive/BMNet_Project/checkpoints_e-4/best_checkpoint.pth"
# Optional external checkpoint location: if a local checkpoint.py defines
# BEST_CHECKPOINT_PATH, prefer it; otherwise fall back to a file next to app.py.
try:
    from checkpoint import BEST_CHECKPOINT_PATH
    CHECKPOINT_PATH = BEST_CHECKPOINT_PATH
except Exception:
    CHECKPOINT_PATH = "best_checkpoint.pth"
# ===============================
# Load regression model
# ===============================
def load_regressor():
    """Create the MNASNet regressor and try to load weights from CHECKPOINT_PATH.

    Returns:
        tuple: (model, ckpt_loaded) — the model on DEVICE in eval mode, and a
        bool saying whether a checkpoint was actually loaded. On failure the
        model keeps its random init (best-effort, so the app still launches).
    """
    model = MNASNetRegressor(num_outputs=14, weights="NONE").to(DEVICE)
    ckpt_loaded = False
    try:
        # NOTE(review): newer PyTorch versions default torch.load to
        # weights_only=True, which can reject fully pickled checkpoints —
        # confirm against the torch version this app pins.
        ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        # Checkpoints may nest the weights under a couple of common keys.
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state = ckpt["state_dict"]
        elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state = ckpt["model_state_dict"]
        else:
            state = ckpt
        target_keys = set(model.state_dict().keys())

        def _match_key(k: str) -> str:
            # Strip wrapper prefixes ("module." from DataParallel, "model."
            # from this wrapper class) from the FRONT only, and only until the
            # key matches the target model. The previous str.replace-based
            # stripping removed those substrings anywhere in the key and
            # always removed "model." — a prefix the target keys actually
            # need — so with strict=False the load could silently match
            # nothing.
            while k not in target_keys:
                if k.startswith("module."):
                    k = k[len("module."):]
                elif k.startswith("model."):
                    k = k[len("model."):]
                else:
                    break
            return k

        new_state = {_match_key(k): v for k, v in state.items()}
        model.load_state_dict(new_state, strict=False)
        ckpt_loaded = True
        print(f"Loaded regression checkpoint: {CHECKPOINT_PATH}")
    except Exception as e:
        print("WARNING could not load regression checkpoint:", e)
        print("Set CHECKPOINT_PATH correctly (or checkpoint.py BEST_CHECKPOINT_PATH).")
    model.eval()
    return model, ckpt_loaded
# Build the regressor once at import time so Gradio callbacks reuse it.
regressor, _ = load_regressor()
# ===============================
# DeepLab for silhouette extraction
# ===============================
# Pretrained DeepLabV3-ResNet50, used only to extract person silhouettes from
# RGB photos; its bundled transforms handle resizing + normalization.
SEG_WEIGHTS = DeepLabV3_ResNet50_Weights.DEFAULT
seg_model = deeplabv3_resnet50(weights=SEG_WEIGHTS).to(DEVICE)
seg_model.eval()
seg_preprocess = SEG_WEIGHTS.transforms()
def extract_person_mask(pil_img: Image.Image) -> np.ndarray:
    """Segment the person in an RGB photo and return a cleaned binary mask.

    Returns a uint8 array with person pixels at 255 and background at 0, at
    the segmentation model's working resolution.
    """
    rgb = pil_img.convert("RGB")
    batch = seg_preprocess(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = seg_model(batch)["out"]
    labels = logits.argmax(1)[0].cpu().numpy()
    # Class index 15 is "person" in this segmentation head's label set.
    mask = np.where(labels == 15, 255, 0).astype(np.uint8)
    # Morphology cleanup: close small holes first, then remove speckle noise.
    kernel = np.ones((5, 5), np.uint8)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=1)
    return mask
def is_silhouette_like(pil_img: Image.Image) -> bool:
    """Heuristically decide whether *pil_img* is already a silhouette.

    Binary silhouettes contain only a handful of distinct gray levels; the
    threshold of 6 tolerates mild compression artifacts.
    """
    levels = np.unique(np.array(pil_img.convert("L")))
    return bool(levels.size <= 6)
def to_bodym_silhouette(pil_img: Image.Image, target_h=640, target_w=480) -> np.ndarray:
    """Convert any input (photo or silhouette) into a BodyM-style binary mask.

    Already-binary images are simply thresholded at 127; RGB photos go through
    the DeepLab person segmenter. The result is resized to
    (target_h, target_w) with nearest-neighbor so the mask stays binary.
    """
    if is_silhouette_like(pil_img):
        gray = np.array(pil_img.convert("L"))
        _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    else:
        mask = extract_person_mask(pil_img)
    resized = Image.fromarray(mask).resize((target_w, target_h), Image.NEAREST)
    return np.array(resized, dtype=np.uint8)
# ===============================
# Input building for MNASNet
# ===============================
def build_input_tensor(front_sil, side_sil, height_cm, weight_kg):
    """Assemble the 3-channel MNASNet input from two silhouettes + metadata.

    Channel 0: front and side silhouettes concatenated side by side (0..1).
    Channels 1-2: constant maps holding min-max-normalized height and weight
    (clipped to [0, 1]). All channels are then ImageNet-normalized, matching
    how the model was trained.

    Returns a (1, 3, H, 2W) float tensor on DEVICE.
    """
    silhouettes = np.concatenate(
        [front_sil.astype(np.float32) / 255.0, side_sil.astype(np.float32) / 255.0],
        axis=1,
    ).astype(np.float32)
    h_frac = float(np.clip((height_cm - HEIGHT_MIN_CM) / (HEIGHT_MAX_CM - HEIGHT_MIN_CM), 0, 1))
    w_frac = float(np.clip((weight_kg - WEIGHT_MIN_KG) / (WEIGHT_MAX_KG - WEIGHT_MIN_KG), 0, 1))
    channels = np.stack(
        [
            silhouettes,
            np.full_like(silhouettes, h_frac, dtype=np.float32),
            np.full_like(silhouettes, w_frac, dtype=np.float32),
        ],
        axis=0,
    )
    channels = (channels - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]
    return torch.from_numpy(channels).float().unsqueeze(0).to(DEVICE)
# ===============================
# Denormalize predictions
# ===============================
def denormalize_predictions(pred_norm: np.ndarray) -> np.ndarray:
    """Map tanh outputs in [-1, 1] back to millimeters via the training min/max."""
    span = Y_MAX_MM - Y_MIN_MM
    return (pred_norm + 1.0) * 0.5 * span + Y_MIN_MM
# ===============================
# Predict
# ===============================
def _parse_gt_cm(gt_fields):
    """Parse optional ground-truth textbox strings into floats (cm).

    Missing, empty, or non-numeric entries become None so the corresponding
    measurement simply shows no error.
    """
    parsed = []
    for field in gt_fields:
        if field is None:
            parsed.append(None)
            continue
        text = str(field).strip()
        if not text:
            parsed.append(None)
            continue
        try:
            parsed.append(float(text))
        except ValueError:
            # Was a bare `except:` before, which also swallowed
            # KeyboardInterrupt/SystemExit; only bad numbers belong here.
            parsed.append(None)
    return parsed


def _error_plot(rows):
    """Build a sorted horizontal bar chart of per-measurement absolute errors.

    Rows without ground truth carry "" in the error column and are skipped.
    Returns None when no ground truth was provided at all.
    """
    pairs = [
        (str(row[0]), float(row[3]))
        for row in rows
        if row[3] != "" and row[3] is not None
    ]
    if not pairs:
        return None
    pairs.sort(key=lambda p: p[1])  # smallest first -> largest bar drawn last
    names = [p[0] for p in pairs]
    errs = [p[1] for p in pairs]
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.barh(names, errs)
    ax.set_xlabel("Absolute Error (cm)")
    ax.set_title("Prediction Error per Measurement (sorted)")
    ax.grid(True, axis="x", alpha=0.3)
    fig.tight_layout()
    return fig


def predict(front_img, side_img, height_cm, weight_kg, *gt_fields):
    """Gradio callback: predict 14 body measurements from two views.

    Args:
        front_img / side_img: PIL images (RGB photos or silhouettes).
        height_cm / weight_kg: metadata fed to the model as constant channels.
        *gt_fields: optional ground-truth strings (cm), one per GT_NAMES entry.

    Returns:
        (status message, table rows [name, pred_cm, gt_cm, abs_err_cm],
        matplotlib figure of absolute errors — or None without ground truth).
    """
    if front_img is None or side_img is None:
        return "Please upload both images.", [], None
    height_cm = float(height_cm)
    weight_kg = float(weight_kg)
    front_sil = to_bodym_silhouette(front_img, target_h=SINGLE_H, target_w=SINGLE_W)
    side_sil = to_bodym_silhouette(side_img, target_h=SINGLE_H, target_w=SINGLE_W)
    x = build_input_tensor(front_sil, side_sil, height_cm, weight_kg)
    with torch.no_grad():
        pred_norm = regressor(x)[0].cpu().numpy()
    # Model outputs are millimeters; the UI reports centimeters.
    y_cm = (denormalize_predictions(pred_norm).astype(np.float32) / 10.0).astype(np.float32)
    gt_vals_cm = _parse_gt_cm(gt_fields)
    rows = []
    for i, name in enumerate(PREDICTION_NAMES):
        pred_cm = float(np.round(y_cm[i], 2))
        gt_cm = gt_vals_cm[i] if i < len(gt_vals_cm) else None
        if gt_cm is None:
            rows.append([name, pred_cm, "", ""])
        else:
            err_cm = float(np.round(abs(pred_cm - float(gt_cm)), 2))
            rows.append([name, pred_cm, float(np.round(gt_cm, 2)), err_cm])
    # Plot is built directly from the table rows to avoid any label mismatch.
    return "Prediction completed.", rows, _error_plot(rows)
# ===============================
# UI (no css=... to avoid your gradio error)
# ===============================
# Build the Gradio UI and launch the app.
with gr.Blocks() as demo:
    gr.Markdown("# Human Body Measurement Predictor (MNASNet + BodyM)")
    gr.Markdown("Upload **RGB photos or silhouettes**. The app detects the type automatically.")
    # Input images: front and side views (photos or binary silhouettes).
    with gr.Row():
        front_in = gr.Image(label="Front View", type="pil", height=260)
        side_in = gr.Image(label="Side View", type="pil", height=260)
    # Metadata inputs fed to the model as constant channels.
    with gr.Row():
        height_in = gr.Number(label="Height input (cm)", value=170)
        weight_in = gr.Number(label="Weight input (kg)", value=70)
    gr.Markdown("### Optional: Ground Truth Body Measurements (cm)")
    gr.Markdown("Fill any of the following to compute absolute error (in **cm**).")
    # One optional textbox per measurement, split across two columns; the
    # order must match GT_NAMES because predict() consumes them positionally.
    gt_inputs = []
    with gr.Row():
        with gr.Column():
            for name in GT_NAMES[:7]:
                gt_inputs.append(gr.Textbox(label=f"{name} (cm)", placeholder="Optional"))
        with gr.Column():
            for name in GT_NAMES[7:]:
                gt_inputs.append(gr.Textbox(label=f"{name} (cm)", placeholder="Optional"))
    run_btn = gr.Button("Predict")
    status_out = gr.Markdown()
    result_table = gr.Dataframe(
        headers=["Measurement", "Predicted (cm)", "Ground truth (cm)", "Abs error (cm)"],
        label="Predicted vs Ground Truth",
        interactive=False,
    )
    error_plot = gr.Plot(label="Absolute Error (cm)")
    # Wire the button: inputs are (front, side, height, weight, *gt_fields),
    # matching predict()'s signature.
    run_btn.click(
        fn=predict,
        inputs=[front_in, side_in, height_in, weight_in] + gt_inputs,
        outputs=[status_out, result_table, error_plot],
    )
demo.launch()