Spaces:
Running
Running
| import json | |
| import os | |
| import sys | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw | |
# Directory holding this script, and its parent (the wider project checkout).
SPACE_DIR = Path(__file__).resolve().parent
PROJECT_DIR = SPACE_DIR.parent

# Put both directories on the import path so ``distillation`` can be imported
# whether it was copied next to this file or lives in the parent project.
# Each is inserted at the front, so PROJECT_DIR ends up ahead of SPACE_DIR.
for path in (SPACE_DIR, PROJECT_DIR):
    if str(path) in sys.path:
        continue
    sys.path.insert(0, str(path))
| try: | |
| from distillation.common import ( | |
| ACTIONS, | |
| bbox_cxcywh_to_xyxy, | |
| box_state, | |
| clamp_xywh, | |
| load_teacher, | |
| render_crop, | |
| render_full_image, | |
| step_box, | |
| ) | |
| except ModuleNotFoundError as exc: | |
| raise ModuleNotFoundError( | |
| "Cannot import distillation.common. Deploy this demo together with the " | |
| "Adacrop/distillation directory, or copy distillation/common.py into the Space repo." | |
| ) from exc | |
# Runtime knobs, each overridable through an environment variable.
# Square side length handed to the render helpers as the model input size.
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
# Step size passed as ``delta`` to the box clamping/stepping helpers.
ACTION_DELTA = float(os.environ.get("ACTION_DELTA", "0.05"))
# Initial position of the "Max RL steps" slider (clamped to 0-200 by the UI).
DEFAULT_MAX_STEPS = int(os.environ.get("DEFAULT_MAX_STEPS", "60"))
# Checkpoint path or filename; see resolve_model_path() for the search order.
MODEL_ENV = os.environ.get("MODEL_PATH", "ppo_best_val_final_score.pth")
def resolve_model_path() -> Path:
    """Locate the model checkpoint named by ``MODEL_PATH`` (``MODEL_ENV``).

    Candidates are tried in this order:

    1. the path itself, when it is absolute;
    2. the path relative to ``SPACE_DIR``, then to ``PROJECT_DIR``;
    3. the bare filename inside ``SPACE_DIR/models``, then ``PROJECT_DIR/models``.

    The first candidate that exists is returned.

    Raises:
        FileNotFoundError: if no candidate exists. The message lists every
            distinct path that was checked.
    """
    raw = Path(MODEL_ENV)
    candidates: List[Path] = []
    if raw.is_absolute():
        candidates.append(raw)
    candidates.extend(
        [
            SPACE_DIR / raw,
            PROJECT_DIR / raw,
            SPACE_DIR / "models" / raw.name,
            PROJECT_DIR / "models" / raw.name,
        ]
    )
    # Joining an absolute ``raw`` onto a directory yields ``raw`` itself, so
    # the list can contain repeats. Drop them (keeping first-seen order) so
    # each path is probed once and reported once.
    candidates = list(dict.fromkeys(candidates))

    for candidate in candidates:
        if candidate.exists():
            return candidate

    checked = "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(
        f"Could not find model checkpoint {MODEL_ENV!r}. Checked:\n{checked}\n"
        "Put ppo_best_val_final_score.pth in the Space root, or set MODEL_PATH."
    )
def get_device() -> torch.device:
    """Choose the torch device used for inference.

    Setting ``FORCE_CPU=1`` pins the demo to the CPU. Otherwise CUDA is used
    whenever ``torch.cuda.is_available()`` reports a GPU, falling back to CPU.
    """
    force_cpu = os.getenv("FORCE_CPU", "0") == "1"
    # Short-circuits, so the CUDA probe is skipped entirely when forced to CPU.
    use_cuda = not force_cpu and torch.cuda.is_available()
    return torch.device("cuda" if use_cuda else "cpu")
@lru_cache(maxsize=1)
def get_model():
    """Load the teacher checkpoint once and reuse it for every request.

    The result is memoised with ``lru_cache(maxsize=1)``. Without this, every
    ``infer`` call would re-read the checkpoint from disk and move it to the
    device again. ``lru_cache`` does not cache raised exceptions, so a missing
    checkpoint is retried on the next request. Call ``get_model.cache_clear()``
    to force a reload.

    NOTE(review): two requests racing on the very first call may both load the
    model. That is harmless but wasteful; add a lock if it matters.

    Returns:
        ``(model, device, model_path)``: the object returned by
        ``load_teacher``, the chosen ``torch.device``, and the resolved
        checkpoint path.

    Raises:
        FileNotFoundError: propagated from ``resolve_model_path`` when the
            checkpoint cannot be found.
    """
    if os.getenv("DISABLE_CUDNN", "0") == "1":
        # Process-wide switch; it stays off for the lifetime of the process.
        torch.backends.cudnn.enabled = False
    device = get_device()
    model_path = resolve_model_path()
    model = load_teacher(model_path, device)
    return model, device, model_path
def predict_bbox(model, image: Image.Image, device: torch.device) -> Tuple[List[float], List[float]]:
    """Predict an initial crop box using the model's bounding-box head.

    The image is rendered at ``IMG_SIZE`` and passed through
    ``model.backbone_forward``. The raw output is clamped to ``[0, 1]`` and
    converted to pixel corners with ``bbox_cxcywh_to_xyxy``. Judging by that
    helper's name, the head presumably emits normalised (cx, cy, w, h); TODO
    confirm.

    Returns:
        ``(init_box, raw_xyxy)``. ``init_box`` is ``clamp_xywh`` applied to the
        corners re-expressed as ``[x, y, w, h]``, with width and height floored
        at 1 px. ``raw_xyxy`` is the unclamped ``(x1, y1, x2, y2)`` conversion
        result.
    """
    img_w, img_h = image.size
    batch = render_full_image(image, IMG_SIZE).unsqueeze(0).to(device)
    with torch.no_grad():
        head_out = model.backbone_forward(batch)
    normalized = head_out.squeeze(0).detach().cpu().clamp(0.0, 1.0).tolist()

    raw_xyxy = bbox_cxcywh_to_xyxy(normalized, img_w, img_h)
    left, top, right, bottom = raw_xyxy
    # Floor the size at 1 px so a collapsed prediction still yields a box.
    box_w = max(1.0, right - left)
    box_h = max(1.0, bottom - top)
    init_box = clamp_xywh([left, top, box_w, box_h], img_w, img_h, delta=ACTION_DELTA)
    return init_box, raw_xyxy
def predict_action(model, image: Image.Image, box_xywh: List[float], device: torch.device) -> Tuple[int, List[float]]:
    """Sample one refinement action from the policy for the current box.

    The policy sees two inputs: the crop rendered at ``IMG_SIZE`` and a state
    vector built by ``box_state`` from the box and the image size. The model
    returns a pair whose first element is used as action probabilities; the
    second is ignored. The action is drawn at random with
    ``torch.distributions.Categorical``, so repeated calls can differ unless
    torch's RNG is seeded.

    Returns:
        ``(action_idx, probs)``: the sampled index and the probabilities as
        plain floats (presumably one per entry of ``ACTIONS``; TODO confirm).
    """
    img_w, img_h = image.size
    crop_obs = render_crop(image, box_xywh, IMG_SIZE).unsqueeze(0).to(device)
    box_feat = box_state(box_xywh, img_w, img_h).unsqueeze(0).to(device)
    with torch.no_grad():
        policy_probs, _value = model(crop_obs, box_feat)
    # Drop the batch dimension of 1 added above.
    dist_probs = policy_probs.squeeze(0).detach().cpu()
    sampled = torch.distributions.Categorical(probs=dist_probs).sample()
    return int(sampled.item()), [float(p) for p in dist_probs.tolist()]
def run_policy(model, image: Image.Image, init_box: List[float], max_steps: int, device: torch.device):
    """Refine ``init_box`` by repeatedly sampling policy actions.

    Runs at most ``max_steps`` iterations. Each sampled action is recorded.
    If it is ``"stop"`` the loop ends without moving the box; otherwise the
    box is advanced with ``step_box``.

    Returns:
        ``(box, actions, action_probs)``: the final box, the action names taken
        (a trailing ``"stop"`` is included), and, for each step, a dict mapping
        each action name to its probability rounded to 4 decimals.
    """
    img_w, img_h = image.size
    current = list(init_box)
    taken = []
    prob_log = []
    for _step in range(max_steps):
        idx, step_probs = predict_action(model, image, current, device)
        name = ACTIONS[idx]
        taken.append(name)
        prob_log.append({label: round(step_probs[k], 4) for k, label in enumerate(ACTIONS)})
        if name == "stop":
            break
        current = step_box(current, idx, img_w, img_h, delta=ACTION_DELTA)
    return current, taken, prob_log
def draw_box(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Return an RGB copy of ``image`` with ``box_xywh`` outlined in red.

    The outline is ``max(3, 0.6% of the shorter image side)`` pixels thick. It
    is built from nested 1-px rectangles that grow outward from the box edge.
    The input image is not modified.
    """
    canvas = image.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    left, top, box_w, box_h = map(float, box_xywh)
    right, bottom = left + box_w, top + box_h
    thickness = max(3, int(min(canvas.size) * 0.006))
    red = (255, 0, 0)
    for grow in range(thickness):
        pen.rectangle([left - grow, top - grow, right + grow, bottom + grow], outline=red)
    return canvas
def crop_image(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Cut the ``(x, y, w, h)`` region out of ``image`` and return it as RGB.

    Float coordinates are passed straight to ``Image.crop``.
    """
    left, top, box_w, box_h = map(float, box_xywh)
    region = image.crop((left, top, left + box_w, top + box_h))
    return region.convert("RGB")
def infer(image, max_steps):
    """Gradio callback: predict a crop for ``image`` and describe the run.

    Args:
        image: the uploaded PIL image, or ``None`` when nothing was uploaded.
        max_steps: slider value, clamped to ``[0, 200]`` and truncated to int.
            ``0`` skips the RL policy and uses the BBox-head box as-is.

    Returns:
        ``(overlay, cropped, details_json)``: the input with the final box drawn
        on it, the cropped region, and a pretty-printed JSON summary.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = image.convert("RGB")
    steps = int(max(0, min(200, max_steps)))

    model, device, model_path = get_model()
    init_box, raw_bbox_xyxy = predict_bbox(model, image, device)

    if steps:
        final_box, actions, action_probs = run_policy(model, image, init_box, steps, device)
        mode = "BBox head + RL policy (sampled actions)"
    else:
        final_box, actions, action_probs = init_box, [], []
        mode = "BBox head only"

    def rounded(values):
        # Boxes are reported to 3 decimals to keep the JSON readable.
        return [round(float(v), 3) for v in values]

    overlay = draw_box(image, final_box)
    cropped = crop_image(image, final_box)
    details = {
        "mode": mode,
        "device": str(device),
        "model_path": str(model_path),
        "image_size": {"width": image.width, "height": image.height},
        "requested_max_steps": steps,
        "actual_steps": len(actions),
        "stopped": actions[-1] == "stop" if actions else False,
        "action_selection": "Categorical(probs).sample()",
        "actions": actions,
        "action_probs": action_probs,
        "initial_box_xywh": rounded(init_box),
        "raw_bbox_head_xyxy": rounded(raw_bbox_xyxy),
        "final_box_xywh": rounded(final_box),
    }
    return overlay, cropped, json.dumps(details, indent=2, ensure_ascii=False)
# Initial slider position: DEFAULT_MAX_STEPS clamped into the slider's 0-200 range.
_SLIDER_DEFAULT = min(max(DEFAULT_MAX_STEPS, 0), 200)

with gr.Blocks(title="Adacrop Core Policy Demo") as demo:
    gr.Markdown("# Adacrop Crop Demo")
    gr.Markdown(
        "Upload an image. Set `max_steps = 0` to use only the BBox head; higher values run the RL policy refinement."
    )
    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Input image", type="pil")
            max_steps = gr.Slider(
                label="Max RL steps",
                minimum=0,
                maximum=200,
                step=1,
                value=_SLIDER_DEFAULT,
            )
            run_button = gr.Button("Crop", variant="primary")
        # Right column: everything ``infer`` returns, in the same order.
        with gr.Column():
            overlay_image = gr.Image(label="Original image with crop box", type="pil")
            cropped_image = gr.Image(label="Cropped result", type="pil")
            info = gr.Code(label="Run details", language="json")
    # Event listeners must be registered inside the Blocks context.
    run_button.click(
        fn=infer,
        inputs=[input_image, max_steps],
        outputs=[overlay_image, cropped_image, info],
    )

if __name__ == "__main__":
    demo.launch()