# adacrop-demo / app.py
# Last commit by zzsyppt: "Update app.py" (5a321ee, verified)
import json
import os
import sys
from functools import lru_cache
from pathlib import Path
from typing import List, Tuple
import gradio as gr
import torch
from PIL import Image, ImageDraw
# Both the Space directory and its parent are put on sys.path so that
# `distillation.common` can be imported whether it was deployed next to this
# file or one level above it.
SPACE_DIR = Path(__file__).resolve().parent
PROJECT_DIR = SPACE_DIR.parent
for path in (SPACE_DIR, PROJECT_DIR):
    # Entries already present are left where they are. Missing ones are
    # pushed to the front, so PROJECT_DIR (handled last) ends up ahead of
    # SPACE_DIR when neither was present before.
    if str(path) in sys.path:
        continue
    sys.path.insert(0, str(path))
# Helpers from the Adacrop `distillation` package: the model loader
# (load_teacher), the ACTIONS vocabulary, and the box / rendering utilities used
# below. They are found through sys.path, which this module extends with the
# Space directory and its parent before reaching this import.
try:
    from distillation.common import (
        ACTIONS,
        bbox_cxcywh_to_xyxy,
        box_state,
        clamp_xywh,
        load_teacher,
        render_crop,
        render_full_image,
        step_box,
    )
except ModuleNotFoundError as exc:
    # Only substitute the deployment hint when the `distillation` package (or
    # one of its submodules) is what is missing. If distillation.common exists
    # but one of *its* own imports fails, exc.name names that other module and
    # the hint would point at the wrong problem, so the real error is re-raised.
    missing = exc.name or ""
    if missing and missing != "distillation" and not missing.startswith("distillation."):
        raise
    raise ModuleNotFoundError(
        "Cannot import distillation.common. Deploy this demo together with the "
        "Adacrop/distillation directory, or copy distillation/common.py into the Space repo."
    ) from exc
# Runtime configuration; every value can be overridden through the environment.
# Square side length passed to render_full_image / render_crop.
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
# `delta` handed to clamp_xywh and step_box when moving/resizing the box.
ACTION_DELTA = float(os.environ.get("ACTION_DELTA", "0.05"))
# Initial value of the "Max RL steps" slider (clamped into 0..200 by the UI).
DEFAULT_MAX_STEPS = int(os.environ.get("DEFAULT_MAX_STEPS", "60"))
# Checkpoint name or path; resolve_model_path() searches several locations for it.
MODEL_ENV = os.environ.get("MODEL_PATH", "ppo_best_val_final_score.pth")
def resolve_model_path() -> Path:
    """Locate the checkpoint file named by ``MODEL_ENV``.

    Search order: ``MODEL_ENV`` itself when it is absolute; then ``MODEL_ENV``
    relative to ``SPACE_DIR`` and to ``PROJECT_DIR``; then a file with the same
    basename under ``models/`` in either directory. The first candidate that
    is an existing regular file wins.

    Returns:
        Path to the first matching checkpoint file.

    Raises:
        FileNotFoundError: if no candidate is an existing file. The message
            lists every path that was checked.
    """
    raw = Path(MODEL_ENV)
    candidates = [raw] if raw.is_absolute() else []
    candidates += [
        SPACE_DIR / raw,
        PROJECT_DIR / raw,
        SPACE_DIR / "models" / raw.name,
        PROJECT_DIR / "models" / raw.name,
    ]
    # Joining an absolute `raw` onto a directory yields `raw` itself, so the
    # list can hold repeats. Drop them (dict keeps insertion order) so each
    # path is probed and reported once.
    candidates = list(dict.fromkeys(candidates))
    for candidate in candidates:
        # is_file() rather than exists(): a directory is never a loadable
        # checkpoint. For example, MODEL_PATH="" gives Path("."), and
        # SPACE_DIR / "." is the Space directory itself.
        if candidate.is_file():
            return candidate
    checked = "\n".join(str(p) for p in candidates)
    raise FileNotFoundError(
        f"Could not find model checkpoint {MODEL_ENV!r}. Checked:\n{checked}\n"
        f"Put {raw.name} in the Space root, or set MODEL_PATH."
    )
def get_device() -> torch.device:
    """Choose the torch device for inference.

    Setting ``FORCE_CPU=1`` in the environment always selects the CPU.
    Otherwise CUDA is used when ``torch.cuda.is_available()`` reports it, and
    the CPU is the fallback.
    """
    cpu_forced = os.getenv("FORCE_CPU", "0") == "1"
    # Short-circuit keeps the CUDA probe from running when the CPU is forced.
    if cpu_forced or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda")
@lru_cache(maxsize=1)
def get_model():
    """Load the policy model once and reuse it on later calls.

    ``lru_cache(maxsize=1)`` on this zero-argument function means the
    checkpoint is located and loaded only on the first call; later calls
    return the same tuple.

    Returns:
        ``(model, device, model_path)``: the object returned by
        ``load_teacher``, the ``torch.device`` it was loaded onto, and the
        checkpoint path found by ``resolve_model_path``.
    """
    # Opt-out switch for cuDNN. It flips a process-wide torch flag, which is
    # why it runs before loading.
    if os.getenv("DISABLE_CUDNN", "0") == "1":
        torch.backends.cudnn.enabled = False
    target_device = get_device()
    checkpoint = resolve_model_path()
    loaded = load_teacher(checkpoint, target_device)
    return loaded, target_device, checkpoint
def predict_bbox(model, image: Image.Image, device: torch.device) -> Tuple[List[float], List[float]]:
    """Run the model's bounding-box head once on the whole image.

    Args:
        model: object exposing ``backbone_forward`` (as returned by ``get_model``).
        image: the RGB input image.
        device: device the rendered input tensor is moved to.

    Returns:
        ``(init_box, raw_xyxy)``.
        - ``raw_xyxy``: the head output converted by ``bbox_cxcywh_to_xyxy``
          for this image's width/height.
        - ``init_box``: ``[x, y, w, h]`` built from ``raw_xyxy`` (width and
          height floored at 1.0) and passed through ``clamp_xywh``.
    """
    img_w, img_h = image.size
    batch = render_full_image(image, IMG_SIZE).unsqueeze(0).to(device)
    with torch.no_grad():
        head_out = model.backbone_forward(batch)
    # Presumably a normalized (cx, cy, w, h) box, judging by the [0, 1] clamp
    # and the converter's name. TODO confirm against distillation.common.
    cxcywh = head_out.squeeze(0).detach().cpu().clamp(0.0, 1.0).tolist()
    raw_xyxy = bbox_cxcywh_to_xyxy(cxcywh, img_w, img_h)
    left, top, right, bottom = raw_xyxy
    # Never hand a zero/negative-size box to clamp_xywh.
    box_w = max(1.0, right - left)
    box_h = max(1.0, bottom - top)
    init_box = clamp_xywh([left, top, box_w, box_h], img_w, img_h, delta=ACTION_DELTA)
    return init_box, raw_xyxy
def predict_action(model, image: Image.Image, box_xywh: List[float], device: torch.device) -> Tuple[int, List[float]]:
    """Sample one refinement action for the current box.

    The crop rendered by ``render_crop`` and the box state from ``box_state``
    are fed to the model. The first output is used as the probabilities of a
    ``Categorical`` distribution, and one action index is sampled from it
    (so results are stochastic).

    Returns:
        ``(action_idx, probs)``: the sampled index and the full probability
        vector as Python floats.
    """
    img_w, img_h = image.size
    crop_batch = render_crop(image, box_xywh, IMG_SIZE).unsqueeze(0).to(device)
    state_batch = box_state(box_xywh, img_w, img_h).unsqueeze(0).to(device)
    with torch.no_grad():
        probs, _ = model(crop_batch, state_batch)
    action_dist = probs.squeeze(0).detach().cpu()
    drawn = torch.distributions.Categorical(probs=action_dist).sample()
    return int(drawn.item()), [float(p) for p in action_dist.tolist()]
def run_policy(model, image: Image.Image, init_box: List[float], max_steps: int, device: torch.device):
    """Refine ``init_box`` by repeatedly sampling policy actions.

    At most ``max_steps`` actions are sampled. Sampling ``"stop"`` ends the
    rollout: the action is recorded, but the box is not stepped. Every other
    action moves the box via ``step_box``.

    Returns:
        ``(box, actions, action_probs)``:
        - ``box``: the final ``[x, y, w, h]``.
        - ``actions``: the sampled action names, in order.
        - ``action_probs``: one ``{action_name: prob}`` dict per step, with
          each probability rounded to 4 decimals.
    """
    img_w, img_h = image.size
    current = list(init_box)  # copy so the caller's init_box is left alone
    chosen = []
    prob_log = []
    for _step in range(max_steps):
        idx, dist = predict_action(model, image, current, device)
        label = ACTIONS[idx]
        chosen.append(label)
        prob_log.append({act: round(dist[k], 4) for k, act in enumerate(ACTIONS)})
        if label == "stop":
            break
        current = step_box(current, idx, img_w, img_h, delta=ACTION_DELTA)
    return current, chosen, prob_log
def draw_box(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Return an RGB copy of ``image`` with ``box_xywh`` outlined in red.

    The stroke is ``max(3, int(0.006 * shorter side))`` pixels thick. It is
    built from one-pixel rectangles that grow outward from the box edge. The
    input image is not modified.
    """
    canvas = image.copy().convert("RGB")
    pen = ImageDraw.Draw(canvas)
    left, top, width, height = map(float, box_xywh)
    right = left + width
    bottom = top + height
    thickness = max(3, int(min(canvas.size) * 0.006))
    for grow in range(thickness):
        pen.rectangle(
            [left - grow, top - grow, right + grow, bottom + grow],
            outline=(255, 0, 0),
        )
    return canvas
def crop_image(image: Image.Image, box_xywh: List[float]) -> Image.Image:
    """Cut the ``[x, y, w, h]`` region out of ``image`` and return it as RGB."""
    left, top, width, height = map(float, box_xywh)
    region = image.crop((left, top, left + width, top + height))
    return region.convert("RGB")
def _coerce_max_steps(value) -> int:
    """Convert a slider/API ``max_steps`` value into an int step budget in [0, 200]."""
    if value is None:
        # An API call that omits the value would otherwise crash inside
        # min()/max(); use the configured default instead.
        value = DEFAULT_MAX_STEPS
    try:
        value = float(value)
    except (TypeError, ValueError, OverflowError) as exc:
        raise gr.Error(f"max_steps must be a number, got {value!r}.") from exc
    # 200 matches the slider's maximum; int() truncates fractional values.
    return int(max(0.0, min(200.0, value)))


def _round_coords(values) -> List[float]:
    """Round each coordinate to 3 decimals for the JSON report."""
    return [round(float(v), 3) for v in values]


def infer(image, max_steps):
    """Gradio callback: predict a crop for ``image`` and report how it was found.

    Args:
        image: the uploaded image (the input component uses ``type="pil"``),
            or ``None`` when nothing was uploaded.
        max_steps: requested number of RL refinement steps.
            - Clamped to [0, 200]; ``0`` uses only the BBox head.
            - ``None`` falls back to ``DEFAULT_MAX_STEPS``.

    Returns:
        ``(overlay, cropped, info_json)``:
        - ``overlay``: the input with the final box drawn in red.
        - ``cropped``: the cropped region.
        - ``info_json``: a pretty-printed JSON string describing the run.

    Raises:
        gr.Error: if no image was provided or ``max_steps`` is not numeric.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = image.convert("RGB")
    max_steps = _coerce_max_steps(max_steps)
    model, device, model_path = get_model()
    init_box, raw_bbox_xyxy = predict_bbox(model, image, device)
    if max_steps == 0:
        # No policy rollout: the clamped BBox-head box is the final answer.
        final_box, actions, action_probs = init_box, [], []
        mode = "BBox head only"
    else:
        final_box, actions, action_probs = run_policy(model, image, init_box, max_steps, device)
        mode = "BBox head + RL policy (sampled actions)"
    overlay = draw_box(image, final_box)
    cropped = crop_image(image, final_box)
    info = {
        "mode": mode,
        "device": str(device),
        "model_path": str(model_path),
        "image_size": {"width": image.width, "height": image.height},
        "requested_max_steps": max_steps,
        "actual_steps": len(actions),
        "stopped": bool(actions and actions[-1] == "stop"),
        "action_selection": "Categorical(probs).sample()",
        "actions": actions,
        "action_probs": action_probs,
        "initial_box_xywh": _round_coords(init_box),
        "raw_bbox_head_xyxy": _round_coords(raw_bbox_xyxy),
        "final_box_xywh": _round_coords(final_box),
    }
    return overlay, cropped, json.dumps(info, indent=2, ensure_ascii=False)
# Gradio UI. Inputs: an image and a step-budget slider. Outputs: the box
# overlay, the cropped result and the JSON run details. The "Crop" button wires
# them to infer().
with gr.Blocks(title="Adacrop Core Policy Demo") as demo:
    gr.Markdown("# Adacrop Crop Demo")
    gr.Markdown("Upload an image. Set `max_steps = 0` to use only the BBox head; higher values run the RL policy refinement.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input image")
            max_steps = gr.Slider(
                minimum=0,
                maximum=200,
                step=1,
                # Clamp the environment-configured default into the slider's range.
                value=min(max(DEFAULT_MAX_STEPS, 0), 200),
                label="Max RL steps",
            )
            run_button = gr.Button("Crop", variant="primary")
        with gr.Column():
            overlay_image = gr.Image(type="pil", label="Original image with crop box")
            cropped_image = gr.Image(type="pil", label="Cropped result")
            info = gr.Code(label="Run details", language="json")
    # infer(image, max_steps) returns (overlay, cropped, info_json) in this output order.
    run_button.click(fn=infer, inputs=[input_image, max_steps], outputs=[overlay_image, cropped_image, info])
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()