# Spaces: Runtime error / Runtime error
# (status banner captured from the Hugging Face Spaces page this file was
# copied from — kept here as a comment so the module stays importable)
| # app.py | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights | |
| import gradio as gr | |
# ===============================
# Device
# ===============================
# Run everything on the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================
# MNASNet Regressor
# ===============================
class MNASNetRegressor(nn.Module):
    """MNASNet-1.0 backbone with a small tanh-bounded regression head.

    The stock classifier is replaced by Linear(1280->128) -> ReLU ->
    Linear(128->num_outputs) -> Tanh, so every output lands in [-1, 1]
    and is denormalized to millimetres downstream.
    """

    def __init__(self, num_outputs: int = 14, weights: str | None = "IMAGENET1K_V1"):
        super().__init__()
        wants_pretrained = not (
            weights is None
            or (isinstance(weights, str) and weights.upper() == "NONE")
        )
        if wants_pretrained:
            # Resolve the weight-enum member by name, e.g. "IMAGENET1K_V1".
            weight_enum = getattr(models, "MNASNet1_0_Weights")[weights]
            backbone = models.mnasnet1_0(weights=weight_enum)
        else:
            backbone = models.mnasnet1_0(weights=None)
        head_in = backbone.classifier[1].in_features
        backbone.classifier = nn.Sequential(
            nn.Linear(head_in, 128),
            nn.ReLU(),
            nn.Linear(128, num_outputs),
            nn.Tanh(),  # bound outputs to [-1, 1]
        )
        self.model = backbone

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch, num_outputs) tensor of values in [-1, 1]."""
        return self.model(x)
# ===============================
# Normalization constants
# ===============================
# Training-set ranges used to min-max normalize the scalar height/weight
# inputs before they are broadcast into constant image channels.
HEIGHT_MIN_CM = 152.3
HEIGHT_MAX_CM = 197.5
WEIGHT_MIN_KG = 46.1
WEIGHT_MAX_KG = 108.2
# 14 outputs: last one is height in mm (as in your setup)
# Per-measurement min/max (mm) used to undo the model's tanh [-1, 1] outputs.
Y_MIN_MM = np.array([
    148.5, 398.2, 204.3, 252.1, 692.7, 211.6, 711.2, 702.9,
    306.4, 703.2, 354.7, 602.5, 151.4, 1512.3
], dtype=np.float32)
Y_MAX_MM = np.array([
    250.1, 802.8, 405.9, 452.3, 1198.4, 403.1, 1102.0, 1199.7,
    502.5, 1203.9, 604.8, 999.6, 249.1, 2001.8
], dtype=np.float32)
# ImageNet channel statistics; the backbone was pretrained with these.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Output order of the regressor's 14 measurements (index-aligned with Y_*_MM).
PREDICTION_NAMES = [
    "ankle","arm-length","bicep","calf","chest",
    "forearm","hip","leg-length","shoulder-breadth",
    "shoulder-to-crotch","thigh","waist","wrist","height"
]
# GT inputs include ALL 14 now (including height) in cm
GT_NAMES = [
    "ankle","arm-length","bicep","calf","chest",
    "forearm","hip","leg-length","shoulder-breadth",
    "shoulder-to-crotch","thigh","waist","wrist","height"
]
# Target size (H, W) of each single-view silhouette before concatenation.
SINGLE_H, SINGLE_W = 640, 480
# ===============================
# Checkpoint path (optional)
# ===============================
# If you have checkpoint.py, it should contain:
# BEST_CHECKPOINT_PATH = "/content/drive/MyDrive/BMNet_Project/checkpoints_e-4/best_checkpoint.pth"
try:
    # Prefer a project-local checkpoint.py override when one exists.
    from checkpoint import BEST_CHECKPOINT_PATH
    CHECKPOINT_PATH = BEST_CHECKPOINT_PATH
except Exception:
    # Fall back to a checkpoint file next to this script.
    CHECKPOINT_PATH = "best_checkpoint.pth"
# ===============================
# Load regression model
# ===============================
def load_regressor():
    """Build the MNASNet regressor and try to load weights from CHECKPOINT_PATH.

    Returns:
        (model, ckpt_loaded): the eval-mode model on DEVICE, plus a bool that
        is False when no checkpoint could be loaded (model stays random).
    """
    model = MNASNetRegressor(num_outputs=14, weights="NONE").to(DEVICE)
    ckpt_loaded = False
    try:
        try:
            # torch >= 2.6 defaults to weights_only=True, which rejects
            # checkpoints that pickle non-tensor objects. This app only loads
            # its own trusted checkpoint, so opt out explicitly.
            ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)
        except TypeError:
            # Older torch without the weights_only keyword.
            ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state = ckpt["state_dict"]
        elif isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            state = ckpt["model_state_dict"]
        else:
            state = ckpt

        def _strip_wrapper_prefixes(key: str) -> str:
            # Remove "module." (DataParallel) / "model." wrapper prefixes only
            # at the START of the key. The previous str.replace() also
            # clobbered "model."/"module." appearing mid-key.
            while key.startswith(("model.", "module.")):
                key = key.removeprefix("module.").removeprefix("model.")
            return key

        # The checkpoint may or may not carry the wrapper prefixes that this
        # model's own state dict uses (its keys start with "model." because of
        # self.model). Try both key layouts and keep whichever matches best —
        # with strict=False a wrong layout would otherwise load NOTHING silently.
        candidates = [
            dict(state.items()),                                    # raw keys
            {_strip_wrapper_prefixes(k): v for k, v in state.items()},
        ]
        model_keys = set(model.state_dict().keys())
        best_state = max(candidates, key=lambda s: len(model_keys & set(s)))
        model.load_state_dict(best_state, strict=False)
        ckpt_loaded = True
        print(f"Loaded regression checkpoint: {CHECKPOINT_PATH}")
    except Exception as e:
        print("WARNING could not load regression checkpoint:", e)
        print("Set CHECKPOINT_PATH correctly (or checkpoint.py BEST_CHECKPOINT_PATH).")
    model.eval()
    return model, ckpt_loaded
regressor, _ = load_regressor()
# ===============================
# DeepLab for silhouette extraction
# ===============================
# Pretrained DeepLabV3-ResNet50 used to turn RGB photos into person masks.
SEG_WEIGHTS = DeepLabV3_ResNet50_Weights.DEFAULT
seg_model = deeplabv3_resnet50(weights=SEG_WEIGHTS).to(DEVICE)
seg_model.eval()
# Weight-matched preprocessing pipeline (resize + ImageNet normalization).
seg_preprocess = SEG_WEIGHTS.transforms()
def extract_person_mask(pil_img: Image.Image) -> np.ndarray:
    """Segment the person in an RGB photo into a uint8 {0, 255} mask.

    Runs DeepLabV3, keeps class 15 (person), then cleans the mask with a
    morphological close (fill holes) followed by an open (drop speckles).
    """
    batch = seg_preprocess(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = seg_model(batch)["out"]
    labels = logits.argmax(1)[0].cpu().numpy()
    person = np.where(labels == 15, 255, 0).astype(np.uint8)  # class 15 = person
    kern = np.ones((5, 5), np.uint8)
    person = cv2.morphologyEx(person, cv2.MORPH_CLOSE, kern, iterations=2)
    person = cv2.morphologyEx(person, cv2.MORPH_OPEN, kern, iterations=1)
    return person
def is_silhouette_like(pil_img: Image.Image) -> bool:
    """Heuristic: an image is already a silhouette if it uses very few
    distinct grayscale levels (<= 6 tolerates mild compression artifacts)."""
    levels = np.unique(np.array(pil_img.convert("L")))
    return levels.size <= 6
def to_bodym_silhouette(pil_img: Image.Image, target_h=640, target_w=480) -> np.ndarray:
    """Convert any input image into a BodyM-style binary silhouette.

    Pre-made silhouettes are just re-binarized at 127; RGB photos go through
    the DeepLab person segmenter. The mask is resized with nearest-neighbor
    interpolation to keep it strictly binary.
    """
    if is_silhouette_like(pil_img):
        gray = np.array(pil_img.convert("L"))
        _, mask = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    else:
        mask = extract_person_mask(pil_img)
    resized = Image.fromarray(mask).resize((target_w, target_h), Image.NEAREST)
    return np.array(resized, dtype=np.uint8)
# ===============================
# Input building for MNASNet
# ===============================
def build_input_tensor(front_sil, side_sil, height_cm, weight_kg):
    """Assemble the (1, 3, H, 2W) network input on DEVICE.

    Channel 0: front and side silhouettes placed side by side, scaled to
    [0, 1]. Channels 1 and 2: constant maps carrying min-max-normalized
    height and weight. All channels are then ImageNet-normalized to match
    the pretrained backbone's statistics.
    """
    silhouettes = np.concatenate(
        [front_sil.astype(np.float32) / 255.0,
         side_sil.astype(np.float32) / 255.0],
        axis=1,
    ).astype(np.float32)
    # Scalars -> [0, 1], clipped to the training range.
    h_norm = float(np.clip(
        (height_cm - HEIGHT_MIN_CM) / (HEIGHT_MAX_CM - HEIGHT_MIN_CM), 0, 1))
    w_norm = float(np.clip(
        (weight_kg - WEIGHT_MIN_KG) / (WEIGHT_MAX_KG - WEIGHT_MIN_KG), 0, 1))
    channels = np.stack(
        [
            silhouettes,
            np.full_like(silhouettes, h_norm, dtype=np.float32),
            np.full_like(silhouettes, w_norm, dtype=np.float32),
        ],
        axis=0,
    )
    channels = (channels - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]
    return torch.from_numpy(channels).float().unsqueeze(0).to(DEVICE)
| # =============================== | |
| # Denormalize predictions | |
| # =============================== | |
| def denormalize_predictions(pred_norm: np.ndarray) -> np.ndarray: | |
| return 0.5 * (pred_norm + 1.0) * (Y_MAX_MM - Y_MIN_MM) + Y_MIN_MM | |
# ===============================
# Predict
# ===============================
def _parse_gt_cm(gt_fields):
    """Parse optional ground-truth textbox values into floats (cm); blank or
    unparseable entries become None."""
    parsed = []
    for field in gt_fields:
        if field is None:
            parsed.append(None)
            continue
        text = str(field).strip()
        if not text:
            parsed.append(None)
            continue
        try:
            parsed.append(float(text))
        except ValueError:  # was a bare except; only float() parse errors belong here
            parsed.append(None)
    return parsed


def _build_rows(y_cm, gt_vals_cm):
    """Build table rows [name, predicted cm, GT cm or "", abs error cm or ""]."""
    rows = []
    for i, name in enumerate(PREDICTION_NAMES):
        pred_cm = float(np.round(y_cm[i], 2))
        gt_cm = gt_vals_cm[i] if i < len(gt_vals_cm) else None
        if gt_cm is None:
            rows.append([name, pred_cm, "", ""])
        else:
            err_cm = float(np.round(abs(pred_cm - float(gt_cm)), 2))
            rows.append([name, pred_cm, float(np.round(gt_cm, 2)), err_cm])
    return rows


def _error_figure(rows):
    """Horizontal bar chart of per-measurement abs error, built DIRECTLY from
    the table rows to avoid label mismatch. Returns None when no GT was given."""
    names, errs = [], []
    for name, _pred, _gt, err in rows:
        if err == "" or err is None:
            continue
        names.append(str(name))
        errs.append(float(err))
    if not errs:
        return None
    errs = np.array(errs, dtype=float)
    names = np.array(names, dtype=str)
    order = np.argsort(errs)  # ascending, so the largest bar is drawn last/top
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.barh(names[order], errs[order])
    ax.set_xlabel("Absolute Error (cm)")
    ax.set_title("Prediction Error per Measurement (sorted)")
    ax.grid(True, axis="x", alpha=0.3)
    fig.tight_layout()
    return fig


def predict(front_img, side_img, height_cm, weight_kg, *gt_fields):
    """Gradio handler: two images + scalars -> (status, table rows, error plot).

    Args:
        front_img, side_img: PIL images (RGB photo or pre-made silhouette).
        height_cm, weight_kg: scalar auxiliary inputs.
        *gt_fields: one optional textbox value per GT_NAMES entry (cm).

    Returns:
        (status message, list of table rows, matplotlib figure or None).
    """
    if front_img is None or side_img is None:
        return "Please upload both images.", [], None
    height_cm = float(height_cm)
    weight_kg = float(weight_kg)
    front_sil = to_bodym_silhouette(front_img, target_h=SINGLE_H, target_w=SINGLE_W)
    side_sil = to_bodym_silhouette(side_img, target_h=SINGLE_H, target_w=SINGLE_W)
    x = build_input_tensor(front_sil, side_sil, height_cm, weight_kg)
    with torch.no_grad():
        pred_norm = regressor(x)[0].cpu().numpy()
    y_mm = denormalize_predictions(pred_norm).astype(np.float32)
    y_cm = (y_mm / 10.0).astype(np.float32)
    rows = _build_rows(y_cm, _parse_gt_cm(gt_fields))
    return "Prediction completed.", rows, _error_figure(rows)
# ===============================
# UI (no css=... to avoid your gradio error)
# ===============================
with gr.Blocks() as demo:
    gr.Markdown("# Human Body Measurement Predictor (MNASNet + BodyM)")
    gr.Markdown("Upload **RGB photos or silhouettes**. The app detects the type automatically.")
    # Image inputs: one front view, one side view.
    with gr.Row():
        front_in = gr.Image(label="Front View", type="pil", height=260)
        side_in = gr.Image(label="Side View", type="pil", height=260)
    # Scalar inputs fed to the auxiliary height/weight channels.
    with gr.Row():
        height_in = gr.Number(label="Height input (cm)", value=170)
        weight_in = gr.Number(label="Weight input (kg)", value=70)
    gr.Markdown("### Optional: Ground Truth Body Measurements (cm)")
    gr.Markdown("Fill any of the following to compute absolute error (in **cm**).")
    # One optional textbox per measurement, split into two columns (7 + 7),
    # in the same order as GT_NAMES so predict() can match them by index.
    gt_inputs = []
    with gr.Row():
        with gr.Column():
            for name in GT_NAMES[:7]:
                gt_inputs.append(gr.Textbox(label=f"{name} (cm)", placeholder="Optional"))
        with gr.Column():
            for name in GT_NAMES[7:]:
                gt_inputs.append(gr.Textbox(label=f"{name} (cm)", placeholder="Optional"))
    run_btn = gr.Button("Predict")
    status_out = gr.Markdown()
    result_table = gr.Dataframe(
        headers=["Measurement", "Predicted (cm)", "Ground truth (cm)", "Abs error (cm)"],
        label="Predicted vs Ground Truth",
        interactive=False,
    )
    error_plot = gr.Plot(label="Absolute Error (cm)")
    # Wire the button: images + scalars + all GT textboxes -> status/table/plot.
    run_btn.click(
        fn=predict,
        inputs=[front_in, side_in, height_in, weight_in] + gt_inputs,
        outputs=[status_out, result_table, error_plot],
    )
demo.launch()