# Copied-page header (not code): AMontiB | upload | 09cf71e | raw | history blame | 11.8 kB
"""
Gradio app for shadow analysis with vanishing-point selection tool
Updated for Gradio 4.x compatibility
"""
import math
import numpy as np
from PIL import Image, ImageDraw
import gradio as gr
from scipy.optimize import minimize
# ------------------------ Helper math functions ---------------------------
def build_line_from_points(p1, p2):
    """Return the coefficients (A, B, C) of the line A*x + B*y + C = 0 through two points.

    The coefficients are not normalised. Two identical points yield the
    degenerate line (0, 0, 0).
    """
    (xa, ya), (xb, yb) = p1, p2
    coeffs = (ya - yb, xb - xa, xa * yb - ya * xb)
    return np.array(coeffs, dtype=float)
def distance_point_to_line(pt, line):
    """Return the perpendicular distance from point *pt* to the line A*x + B*y + C = 0.

    *line* is the (A, B, C) coefficient triple. For a degenerate line
    (A == B == 0) the denominator is zero. Plain floats then raise
    ZeroDivisionError, and numpy floats give nan/inf with a warning, so
    callers should not pass such lines.
    """
    a, b, c = line
    px, py = pt
    residual = a * px + b * py + c
    return abs(residual) / math.hypot(a, b)
def total_distances(x, lines, noise_lines):
    """Objective for the vanishing-point search.

    Returns the sum of distances from candidate point *x* = (x, y) to every
    line in *lines* followed by every line in *noise_lines* (each an
    (A, B, C) triple, or a row of an array of such triples).
    """
    # Start at 0.0 and walk both groups in order, so the float accumulation
    # order matches a plain running sum.
    return sum(
        (distance_point_to_line(x, line) for group in (lines, noise_lines) for line in group),
        0.0,
    )
def add_noise_lines_for_line(p1, p2, n=4, sigma=1.0):
    """Return *n* perturbed copies of the line through *p1* and *p2*.

    Each copy is built from endpoints shifted by independent Gaussian noise
    (mean 0, standard deviation *sigma*) drawn from numpy's global RNG, so the
    result is random unless the caller seeds ``np.random``.
    """
    def jitter(pt):
        # x offset is drawn before y offset
        return (pt[0] + np.random.normal(0, sigma), pt[1] + np.random.normal(0, sigma))

    # Arguments evaluate left to right: p1 is jittered before p2 on every copy.
    return [build_line_from_points(jitter(p1), jitter(p2)) for _ in range(n)]
# ------------------------- Drawing utilities ------------------------------
def draw_overlay(base_pil, yellow_lines, red_lines, yellow_points, red_points, vps=None):
    """Return a new RGB image with lines, their endpoints and vanishing points drawn on it.

    Args:
        base_pil: PIL image to annotate. It is not modified; a converted copy
            is drawn on.
        yellow_lines, red_lines: line coefficients (a, b, c) for each group.
        yellow_points, red_points: the clicked (p1, p2) endpoint pairs,
            parallel to the matching *_lines list. ``zip`` pairs them, so any
            surplus entries are ignored.
        vps: optional dict whose "yellow" / "red" entries are an (x, y)
            vanishing point or None. Each present point is drawn as a larger dot.

    Returns:
        A new PIL image in RGB mode.

    Lines use alpha 200 and points alpha 255. The copy is converted to RGB and
    drawn on in "RGBA" mode so those alpha values are actually blended.
    Drawing straight onto an RGBA image would overwrite pixels (alpha
    included), and the final RGB conversion would discard alpha, leaving every
    mark fully opaque.
    """
    img = base_pil.convert("RGB")  # convert() returns a new image
    draw = ImageDraw.Draw(img, "RGBA")

    def draw_point(pt, color, r=4):
        x, y = pt
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)

    yellow_rgb = (255, 215, 0)
    red_rgb = (255, 64, 64)

    # Yellow group first, then red, so red marks end up on top.
    for rgb, pairs, lines in ((yellow_rgb, yellow_points, yellow_lines),
                              (red_rgb, red_points, red_lines)):
        for (p1, p2), line in zip(pairs, lines):
            draw_line_segment_from_line(line, img.size, color=rgb + (200,), draw=draw)
            draw_point(p1, rgb + (255,))
            draw_point(p2, rgb + (255,))

    if vps is not None:
        for key, rgb in (("yellow", yellow_rgb), ("red", red_rgb)):
            vp = vps.get(key)
            if vp is not None:
                draw_point(vp, rgb + (255,), r=6)

    return img
def draw_line_segment_from_line(line, image_size, draw=None, color=(255, 255, 0, 255)):
    """Draw the line a*x + b*y + c = 0 as a segment spanning the image.

    Args:
        line: (a, b, c) coefficients.
        image_size: (width, height) of the target image.
        draw: object with a ``line(xy, fill=..., width=...)`` method, such as a
            PIL ImageDraw. When None, nothing is drawn.
        color: fill colour forwarded to ``draw.line``.

    The line is intersected with the four image borders. Crossings within a
    10% margin around the image are kept, and the segment joins the leftmost
    and rightmost of them (ties broken by y). Nothing is drawn when fewer than
    two crossings survive.
    """
    width, height = image_size
    a, b, c = line

    candidates = []
    if abs(b) > 1e-9:
        # crossings with the left (x = 0) and right (x = width) borders
        candidates += [(0, -c / b), (width, -(a * width + c) / b)]
    if abs(a) > 1e-9:
        # crossings with the top (y = 0) and bottom (y = height) borders
        candidates += [(-c / a, 0), (-(b * height + c) / a, height)]

    x_lo, x_hi = -width * 0.1, width * 1.1
    y_lo, y_hi = -height * 0.1, height * 1.1
    inside = sorted(
        pt for pt in candidates
        if x_lo <= pt[0] <= x_hi and y_lo <= pt[1] <= y_hi
    )

    if draw is None or len(inside) < 2:
        return
    (xa, ya), (xb, yb) = inside[0], inside[-1]
    draw.line((xa, ya, xb, yb), fill=color, width=2)
# ------------------------- Gradio app callbacks ---------------------------
def on_mode_change(mode, image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Switch the drawing mode to *mode* ('yellow', 'red' or None).

    Any half-entered line (a pending first click) is discarded. The image and
    the completed line groups pass through unchanged. The previous mode and
    pending points are accepted only to match the shared callback signature.
    """
    pending_points = []  # a mode switch always starts a fresh line
    return image, mode, pending_points, y_lines, r_lines, y_pairs, r_pairs
def on_image_click(evt: gr.SelectData, image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Record a click on the image; every second click becomes a line.

    Args:
        evt: Gradio 4.x select event. ``evt.index`` holds the clicked pixel as (x, y).
        image: current image (numpy array or PIL image), or None if nothing
            has been uploaded.
        current_mode: 'yellow', 'red' or None (no drawing).
        current_points: clicks still waiting for their partner.
        y_lines, r_lines: line coefficients collected for each colour.
        y_pairs, r_pairs: the (p1, p2) click pairs behind those lines.

    Returns:
        (image, mode, pending points, y_lines, r_lines, y_pairs, r_pairs).
        The image is re-rendered with all lines when the click is accepted.

    Clicks are ignored when there is no image or no active drawing mode. A
    second click on the same pixel as the pending first click is also ignored:
    it would define the degenerate line (0, 0, 0), which has no direction and
    makes the point-to-line distance divide by zero during optimisation.
    """
    unchanged = (image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs)
    if image is None or current_mode not in ("yellow", "red"):
        return unchanged

    x, y = evt.index
    click = (x, y)
    points = list(current_points) if current_points is not None else []
    if points and points[-1] == click:
        return unchanged

    points.append(click)
    if len(points) >= 2:
        p1, p2 = points[-2], points[-1]
        line = build_line_from_points(p1, p2)
        # Copy before appending so the previous state objects are not mutated.
        if current_mode == "yellow":
            y_lines = (list(y_lines) if y_lines is not None else []) + [line]
            y_pairs = (list(y_pairs) if y_pairs is not None else []) + [(p1, p2)]
        else:
            r_lines = (list(r_lines) if r_lines is not None else []) + [line]
            r_pairs = (list(r_pairs) if r_pairs is not None else []) + [(p1, p2)]
        points = []  # start collecting the next line

    base_pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    out = draw_overlay(base_pil, y_lines or [], r_lines or [], y_pairs or [], r_pairs or [], vps=None)
    return np.array(out), current_mode, points, y_lines, r_lines, y_pairs, r_pairs
def _estimate_vanishing_point(lines, pairs, fallback):
    """Estimate the point where one colour group of lines (nearly) meets.

    Args:
        lines: line coefficients (a, b, c), one per drawn line.
        pairs: the clicked (p1, p2) endpoints of each line, parallel to *lines*.
        fallback: (x, y) starting guess used when no two lines intersect.

    Returns:
        The (x, y) that Powell's method finds when minimising the summed
        distances to the lines plus jittered copies of them. Returns None when
        fewer than two non-degenerate lines are available.
    """
    from itertools import combinations

    # A line with a == b == 0 (two identical clicks) has no direction and
    # would make the distance objective divide by zero; skip it and its pair.
    usable = [
        (np.asarray(line, dtype=float), pair)
        for line, pair in zip(lines, pairs)
        if math.hypot(line[0], line[1]) > 0
    ]
    if len(usable) < 2:
        return None
    lines_arr = np.array([line for line, _ in usable])

    # Starting guess: mean of all pairwise line intersections.
    inters = []
    for li, lj in combinations(lines_arr, 2):
        try:
            inters.append(
                np.linalg.solve(np.array([li[:2], lj[:2]]), -np.array([li[2], lj[2]]))
            )
        except np.linalg.LinAlgError:
            # Exactly parallel lines have no finite intersection.
            continue
    p0 = np.mean(inters, axis=0) if inters else np.asarray(fallback, dtype=float)

    # Jittered copies (random, drawn from np.random) regularise the objective.
    noise = []
    for _, (p1, p2) in usable:
        noise += add_noise_lines_for_line(p1, p2, n=4, sigma=2.0)

    res = minimize(lambda x: total_distances(x, lines_arr, noise), p0, method='Powell')
    return (float(res.x[0]), float(res.x[1]))


def compute_vanishing_points(image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Compute the yellow and red vanishing points and draw them on the image.

    A colour group needs at least two non-degenerate lines; a group with fewer
    gets no vanishing point. The image centre is the starting guess when a
    group has no pairwise intersection. The states pass through unchanged. The
    first return value is the freshly annotated image as a numpy array, or
    None when there is no image.
    """
    if image is None:
        return image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs

    img_pil = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    center = (img_pil.width / 2, img_pil.height / 2)

    # Yellow is estimated before red; both draw noise from the global np.random.
    vps = {
        "yellow": _estimate_vanishing_point(y_lines or [], y_pairs or [], center),
        "red": _estimate_vanishing_point(r_lines or [], r_pairs or [], center),
    }

    out = draw_overlay(img_pil, y_lines or [], r_lines or [], y_pairs or [], r_pairs or [], vps=vps)
    return np.array(out), current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs
def reset_all(image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Clear the drawing mode, pending clicks and all collected lines.

    The image is returned as a numpy array: an ndarray is passed through as-is,
    anything else is converted with ``np.array``, and None stays None.

    NOTE(review): the click and compute callbacks write their annotated output
    back into the image component, so marks drawn earlier are part of the
    image itself and are not removed here.
    """
    if image is not None and not isinstance(image, np.ndarray):
        image = np.array(image)
    return image, None, [], [], [], [], []
# ------------------------------ Build Blocks ------------------------------
def build_gradio_interface():
    """Assemble and return the Gradio Blocks UI.

    The image sits beside a column of control buttons and usage instructions.
    Every callback receives and returns the same seven values: the image
    component followed by the six per-session states declared below.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Shadow Analysis - Vanishing Point Detection")

        with gr.Row():
            img_in = gr.Image(
                label="Upload image and click to add points",
                type="numpy",
                interactive=True,
                height=600,
            )
            with gr.Column():
                start_y = gr.Button("Start Yellow Lines")
                start_r = gr.Button("Start Red Lines")
                none_btn = gr.Button("Stop Drawing")
                compute_btn = gr.Button("Compute Vanishing Points")
                reset_btn = gr.Button("Reset All")
                gr.Markdown(
                    "\n"
                    "**Instructions:**\n"
                    "1. Upload an image\n"
                    "2. Click 'Start Yellow' or 'Start Red' to choose line color\n"
                    "3. Click on the image to add points (2 points = 1 line)\n"
                    "4. Add at least 2 lines per color group\n"
                    "5. Click 'Compute Vanishing Points' to analyze\n"
                )

        # Per-session state shared by every callback.
        current_mode = gr.State(None)   # 'yellow', 'red' or None
        current_points = gr.State([])   # clicks waiting for their partner
        y_lines = gr.State([])
        r_lines = gr.State([])
        y_pairs = gr.State([])
        r_pairs = gr.State([])
        session = [img_in, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs]

        # Mode buttons: each feeds its fixed mode value in as the first input.
        for button, mode in ((start_y, "yellow"), (start_r, "red"), (none_btn, None)):
            button.click(
                fn=on_mode_change,
                inputs=[gr.State(mode)] + session,
                outputs=list(session),
            )

        # Clicking the image (Gradio 4.x select event) adds points / lines.
        img_in.select(fn=on_image_click, inputs=list(session), outputs=list(session))
        compute_btn.click(fn=compute_vanishing_points, inputs=list(session), outputs=list(session))
        reset_btn.click(fn=reset_all, inputs=list(session), outputs=list(session))

    return demo
if __name__ == '__main__':
    # Build the UI and start the Gradio server; `demo` stays bound at module level.
    demo = build_gradio_interface()
    demo.launch()