"""Gradio demo: predict a crop box for an uploaded image and optionally refine it.

A checkpoint is loaded through ``distillation.common.load_teacher``. The model's
``backbone_forward`` output is turned into an initial box, which can then be
adjusted for up to ``max_steps`` iterations by sampling actions from the
model's policy output.

Environment variables read at import time: ``IMG_SIZE``, ``ACTION_DELTA``,
``DEFAULT_MAX_STEPS``, ``MODEL_PATH``. Read lazily: ``FORCE_CPU``,
``DISABLE_CUDNN``.
"""

import json
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import torch
from PIL import Image, ImageDraw

SPACE_DIR = Path(__file__).resolve().parent
PROJECT_DIR = SPACE_DIR.parent

# Make both the Space directory and its parent importable so that the
# `distillation` package resolves whether it sits next to this file or one
# level above it.
for path in (SPACE_DIR, PROJECT_DIR):
    if str(path) not in sys.path:
        sys.path.insert(0, str(path))

try:
    from distillation.common import (
        ACTIONS,
        bbox_cxcywh_to_xyxy,
        box_state,
        clamp_xywh,
        load_teacher,
        render_crop,
        render_full_image,
        step_box,
    )
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "Cannot import distillation.common. Deploy this demo together with the "
        "Adacrop/distillation directory, or copy distillation/common.py into the Space repo."
    ) from exc

IMG_SIZE = int(os.getenv("IMG_SIZE", "224"))
ACTION_DELTA = float(os.getenv("ACTION_DELTA", "0.05"))
DEFAULT_MAX_STEPS = int(os.getenv("DEFAULT_MAX_STEPS", "60"))
MODEL_ENV = os.getenv("MODEL_PATH", "ppo_best_val_final_score.pth")


def resolve_model_path() -> Path:
    """Return the first existing checkpoint path derived from ``MODEL_PATH``.

    Candidates, in order: the raw path when absolute, the path relative to
    ``SPACE_DIR`` and ``PROJECT_DIR``, then the bare file name under each
    directory's ``models`` folder.

    Raises:
        FileNotFoundError: if none of the candidates exists; the message
            lists every location that was checked.
    """
    raw = Path(MODEL_ENV)
    candidates = []
    if raw.is_absolute():
        candidates.append(raw)
    candidates.extend(
        [
            SPACE_DIR / raw,
            PROJECT_DIR / raw,
            SPACE_DIR / "models" / raw.name,
            PROJECT_DIR / "models" / raw.name,
        ]
    )
    # Joining a directory with an absolute path yields that absolute path, so
    # an absolute MODEL_PATH would otherwise appear three times. Drop the
    # repeats while keeping the probe order.
    candidates = list(dict.fromkeys(candidates))

    for candidate in candidates:
        if candidate.exists():
            return candidate

    checked = "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(
        f"Could not find model checkpoint {MODEL_ENV!r}. Checked:\n{checked}\n"
        "Put ppo_best_val_final_score.pth in the Space root, or set MODEL_PATH."
    )


def get_device() -> torch.device:
    """Return the CPU device if ``FORCE_CPU=1``, else CUDA when available, else CPU."""
    if os.getenv("FORCE_CPU", "0") == "1":
        return torch.device("cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


@lru_cache(maxsize=1)
def get_model():
    """Load the model once and return ``(model, device, model_path)``.

    The result is memoised, so the checkpoint is resolved and loaded only on
    the first call. Setting ``DISABLE_CUDNN=1`` turns cuDNN off globally
    before loading.
    """
    if os.getenv("DISABLE_CUDNN", "0") == "1":
        torch.backends.cudnn.enabled = False
    device = get_device()
    model_path = resolve_model_path()
    model = load_teacher(model_path, device)
    return model, device, model_path


def predict_bbox(model, image: Image.Image, device: torch.device) -> Tuple[List[float], List[float]]:
    """Compute the initial crop box from ``model.backbone_forward``.

    Args:
        model: object returned by ``load_teacher``.
        image: input image.
        device: device the rendered image tensor is moved to.

    Returns:
        ``(init_box, raw_xyxy)``: ``init_box`` is the ``[x, y, w, h]`` box
        after ``clamp_xywh``; ``raw_xyxy`` is the value returned by
        ``bbox_cxcywh_to_xyxy`` before clamping.
    """
    width, height = image.size
    img_t = render_full_image(image, IMG_SIZE).unsqueeze(0).to(device)
    with torch.no_grad():
        # Output values are clamped to [0, 1] before conversion.
        # NOTE(review): presumably normalised (cx, cy, w, h) -- confirm
        # against distillation.common.
        pred = model.backbone_forward(img_t).squeeze(0).detach().cpu().clamp(0.0, 1.0).tolist()

    raw_xyxy = bbox_cxcywh_to_xyxy(pred, width, height)
    x1, y1, x2, y2 = raw_xyxy
    # Width and height are floored at 1.0 so a degenerate prediction never
    # produces an empty box.
    init_box = clamp_xywh(
        [x1, y1, max(1.0, x2 - x1), max(1.0, y2 - y1)],
        width,
        height,
        delta=ACTION_DELTA,
    )
    return init_box, raw_xyxy


def predict_action(model, image: Image.Image, box_xywh: List[float], device: torch.device) -> Tuple[int, List[float]]:
    """Sample one action for the current box.

    The action is drawn at random from a categorical distribution over the
    model's output probabilities (not the arg-max), so repeated runs on the
    same image can differ.

    Returns:
        ``(action_idx, probs)``: the sampled index and the full probability
        vector as plain floats.
    """
    width, height = image.size
    obs = render_crop(image, box_xywh, IMG_SIZE).unsqueeze(0).to(device)
    state = box_state(box_xywh, width, height).unsqueeze(0).to(device)
    with torch.no_grad():
        probs, _ = model(obs, state)
    probs_1d = probs.squeeze(0).detach().cpu()
    action_idx = int(torch.distributions.Categorical(probs=probs_1d).sample().item())
    return action_idx, [float(v) for v in probs_1d.tolist()]


def run_policy(model, image: Image.Image, init_box: List[float], max_steps: int, device: torch.device):
    """Refine ``init_box`` by sampling up to ``max_steps`` actions.

    The loop ends early when the sampled action is named ``"stop"``; that
    final action is still recorded in the returned history.

    Returns:
        ``(box, actions, action_probs)``: the final box, the list of action
        names taken, and one ``{action_name: probability}`` dict per step
        (probabilities rounded to 4 decimals).
    """
    width, height = image.size
    box = list(init_box)  # copy so the caller's list is never aliased
    actions = []
    action_probs = []
    for _ in range(max_steps):
        action_idx, probs = predict_action(model, image, box, device)
        action_name = ACTIONS[action_idx]
        actions.append(action_name)
        action_probs.append({name: round(probs[i], 4) for i, name in enumerate(ACTIONS)})
        if action_name == "stop":
            break
        box = step_box(box, action_idx, width, height, delta=ACTION_DELTA)
    return box, actions, action_probs


def draw_box(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Return an RGB copy of ``image`` with ``box_xywh`` outlined in red.

    The outline thickness is 0.6% of the shorter image side, with a minimum
    of 3 pixels. It is built from concentric 1-pixel rectangles that grow
    outwards from the box.
    """
    out = image.copy().convert("RGB")
    draw = ImageDraw.Draw(out)
    x, y, w, h = [float(v) for v in box_xywh]
    x2, y2 = x + w, y + h
    line_width = max(3, int(min(out.size) * 0.006))
    for offset in range(line_width):
        draw.rectangle([x - offset, y - offset, x2 + offset, y2 + offset], outline=(255, 0, 0))
    return out


def crop_image(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Return the region ``[x, y, w, h]`` of ``image`` as an RGB image."""
    x, y, w, h = [float(v) for v in box_xywh]
    return image.crop((x, y, x + w, y + h)).convert("RGB")


def infer(image, max_steps):
    """Gradio callback: run the crop pipeline on one uploaded image.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        max_steps: requested number of refinement steps; clamped to
            ``[0, 200]``. ``0`` skips the policy and uses the initial box.

    Returns:
        ``(overlay, cropped, info_json)``: the image with the final box drawn
        on it, the cropped image, and a JSON string describing the run.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image = image.convert("RGB")
    max_steps = int(max(0, min(200, max_steps)))
    model, device, model_path = get_model()

    init_box, raw_bbox_xyxy = predict_bbox(model, image, device)
    if max_steps == 0:
        final_box = init_box
        actions = []
        action_probs = []
        mode = "BBox head only"
    else:
        final_box, actions, action_probs = run_policy(model, image, init_box, max_steps, device)
        mode = "BBox head + RL policy (sampled actions)"

    overlay = draw_box(image, final_box)
    cropped = crop_image(image, final_box)
    info = {
        "mode": mode,
        "device": str(device),
        "model_path": str(model_path),
        "image_size": {"width": image.width, "height": image.height},
        "requested_max_steps": max_steps,
        "actual_steps": len(actions),
        "stopped": bool(actions and actions[-1] == "stop"),
        "action_selection": "Categorical(probs).sample()",
        "actions": actions,
        "action_probs": action_probs,
        "initial_box_xywh": [round(float(v), 3) for v in init_box],
        "raw_bbox_head_xyxy": [round(float(v), 3) for v in raw_bbox_xyxy],
        "final_box_xywh": [round(float(v), 3) for v in final_box],
    }
    return overlay, cropped, json.dumps(info, indent=2, ensure_ascii=False)


# UI layout: inputs in the left column, results in the right column.
with gr.Blocks(title="Adacrop Core Policy Demo") as demo:
    gr.Markdown("# Adacrop Crop Demo")
    gr.Markdown("Upload an image. Set `max_steps = 0` to use only the BBox head; higher values run the RL policy refinement.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input image")
            max_steps = gr.Slider(
                minimum=0,
                maximum=200,
                step=1,
                # Keep the env-provided default inside the slider's range.
                value=min(max(DEFAULT_MAX_STEPS, 0), 200),
                label="Max RL steps",
            )
            run_button = gr.Button("Crop", variant="primary")
        with gr.Column():
            overlay_image = gr.Image(type="pil", label="Original image with crop box")
            cropped_image = gr.Image(type="pil", label="Cropped result")
            info = gr.Code(label="Run details", language="json")

    run_button.click(fn=infer, inputs=[input_image, max_steps], outputs=[overlay_image, cropped_image, info])


if __name__ == "__main__":
    demo.launch()