File size: 15,835 Bytes
fc8df74
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc8df74
 
40598e4
fc8df74
 
40598e4
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
40598e4
fc8df74
 
 
 
 
40598e4
fc8df74
 
 
 
 
 
 
 
 
 
40598e4
fc8df74
 
 
 
 
 
 
 
 
 
 
 
40598e4
 
 
 
 
 
09cf71e
fc8df74
 
40598e4
fc8df74
 
 
 
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc8df74
 
 
40598e4
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40598e4
fc8df74
40598e4
 
 
fc8df74
 
 
40598e4
fc8df74
40598e4
fc8df74
40598e4
 
fc8df74
 
40598e4
fc8df74
40598e4
fc8df74
40598e4
 
fc8df74
 
 
 
 
 
40598e4
 
fc8df74
 
 
 
 
 
 
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
fc8df74
40598e4
 
 
 
fc8df74
 
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc8df74
 
40598e4
 
 
fc8df74
 
 
40598e4
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
40598e4
 
 
fc8df74
 
40598e4
 
fc8df74
 
40598e4
fc8df74
40598e4
 
 
 
fc8df74
 
 
40598e4
fc8df74
 
40598e4
fc8df74
 
 
 
40598e4
 
fc8df74
 
 
 
 
 
40598e4
fc8df74
 
40598e4
fc8df74
 
 
 
 
 
 
40598e4
fc8df74
 
 
 
 
 
40598e4
 
fc8df74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40598e4
fc8df74
 
f479f2f
 
 
8384f92
 
 
 
 
0e73e98
 
 
 
fc8df74
 
 
 
09cf71e
5f9c46d
61793df
fc8df74
40598e4
09cf71e
afd1f17
 
fc8df74
40598e4
afd1f17
40598e4
 
 
fc8df74
 
 
 
 
 
f479f2f
 
0e73e98
f479f2f
 
 
0e73e98
 
 
7c57ed0
 
f479f2f
0e73e98
09cf71e
40598e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f479f2f
40598e4
fc8df74
 
 
40598e4
09cf71e
 
40598e4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
"""
Gradio app to replicate the interactive vanishing-point selection tool
from the supplied matplotlib script, implemented for gradio==3.50.2.

How it works (UI):
- Upload an image.
- Click "Start Yellow Line" or "Start Red Line" to enter a drawing mode for that
  line group ("Stop Drawing" leaves drawing mode).
- Click on the image to add points. Two consecutive clicks make a line.
- You can add as many lines as you want for each color.
- Press "Compute vanishing points" to run the optimization
  (scipy.optimize.minimize) for each color group that has at least two lines,
  and display the vanishing points and overlaid lines.
- "Reset Figure and Clear Lines" clears all lines and restores the uploaded image.

Requirements:
- gradio==3.50.2
- numpy
- scipy
- pillow

Run:
    pip install gradio==3.50.2 numpy scipy pillow
    python grad_io_gradio_app.py

Note: This implementation uses the Image.select event which behaves correctly
in gradio 3.50.2 (it provides pixel coordinates of the clicked point). If you
use a newer Gradio version, the event behavior might differ.

"""

import io
import math
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from scipy.optimize import minimize

# ------------------------ Helper math functions ---------------------------

def build_line_from_points(p1, p2):
    """Return the coefficients (A, B, C) of the line Ax + By + C = 0 through p1 and p2.

    The triple equals the cross product of the homogeneous points
    (x1, y1, 1) x (x2, y2, 1). It is not normalised, and two identical points
    yield the degenerate triple (0, 0, 0).
    """
    (x1, y1), (x2, y2) = p1, p2
    coeff_a = y1 - y2
    coeff_b = x2 - x1
    coeff_c = x1 * y2 - y1 * x2
    return np.array([coeff_a, coeff_b, coeff_c], dtype=float)


def distance_point_to_line(pt, line):
    """Return the perpendicular distance from point ``pt`` = (x, y) to ``line``.

    ``line`` is the coefficient triple (A, B, C) of Ax + By + C = 0 and does
    not need to be normalised. A degenerate line with A = B = 0 has zero norm
    and is not guarded against here.
    """
    px, py = pt
    coeff_a, coeff_b, coeff_c = line
    residual = coeff_a * px + coeff_b * py + coeff_c
    normal_length = math.hypot(coeff_a, coeff_b)
    return abs(residual) / normal_length


def total_distances(x, lines, noise_lines):
    """Objective for the vanishing-point fit: summed point-to-line distances.

    Accumulates the distance from the candidate point ``x`` to every user line
    first, then to every jittered noise line, adding them one at a time in that
    order (an explicit loop keeps the floating-point summation order fixed).
    """
    total = 0.0
    for group in (lines, noise_lines):
        for line in group:
            total += distance_point_to_line(x, line)
    return total


def add_noise_lines_for_line(p1, p2, n=4, sigma=1.0, rng=None):
    """Create ``n`` "noise" lines by jittering both endpoints of the segment p1-p2.

    Every endpoint coordinate is perturbed with zero-mean Gaussian noise of
    standard deviation ``sigma``. The jittered endpoints are then converted to
    line coefficients with ``build_line_from_points``.

    Args:
        p1, p2: (x, y) endpoints of the original line.
        n: number of noise lines to generate.
        sigma: standard deviation of the per-coordinate jitter.
        rng: optional NumPy random generator (e.g. ``np.random.default_rng(seed)``)
            for reproducible jitter. When ``None``, the global ``np.random``
            state is used, as before.

    Returns:
        A list of ``n`` coefficient arrays (A, B, C).
    """
    normal = np.random.normal if rng is None else rng.normal
    noise_lines = []
    for _ in range(n):
        # Draw order (p1.x, p1.y, p2.x, p2.y) is kept so seeded runs stay comparable.
        p1n = (p1[0] + normal(0, sigma), p1[1] + normal(0, sigma))
        p2n = (p2[0] + normal(0, sigma), p2[1] + normal(0, sigma))
        noise_lines.append(build_line_from_points(p1n, p2n))
    return noise_lines

# ------------------------- Drawing utilities ------------------------------

def draw_overlay(base_pil, yellow_lines, red_lines, yellow_points, red_points, vps=None):
    """Return a new PIL image with line groups, clicked endpoints and vanishing points drawn.

    Args:
        base_pil: PIL image to annotate. It is copied, not modified.
        yellow_lines, red_lines: lists of line coefficients (A, B, C).
        yellow_points, red_points: lists of ``(p1, p2)`` endpoint pairs, one per
            line, index-aligned with the corresponding coefficient list.
        vps: optional dict with keys ``'yellow'`` / ``'red'`` mapping to an
            ``(x, y)`` vanishing point or ``None``.

    Returns:
        A new RGB PIL image.
    """
    img = base_pil.copy().convert("RGBA")
    draw = ImageDraw.Draw(img)
    yellow = (255, 215, 0)
    red = (255, 64, 64)

    def draw_point(pt, color, r=4):
        """Draw a filled circle of radius ``r`` centred on ``pt``."""
        x, y = pt
        draw.ellipse((x - r, y - r, x + r, y + r), fill=color, outline=color)

    # Yellow group first, then red; each line is extended across the image bounds.
    for pairs, lines, rgb in ((yellow_points, yellow_lines, yellow),
                              (red_points, red_lines, red)):
        for (p1, p2), line in zip(pairs, lines):
            draw_line_segment_from_line(line, img.size, color=rgb + (200,), draw=draw)
            draw_point(p1, rgb + (255,))
            draw_point(p2, rgb + (255,))

    # Vanishing points are drawn last, with a larger radius.
    if vps is not None:
        for key, rgb in (("yellow", yellow), ("red", red)):
            if vps.get(key) is not None:
                draw_point(vps[key], rgb + (255,), r=6)

    return img.convert("RGB")


def draw_line_segment_from_line(line, image_size, draw=None, color=(255, 255, 0, 255)):
    """Draw the infinite line Ax + By + C = 0 loosely clipped to the image.

    The line is intersected with the borders x=0, x=W, y=0 and y=H, skipping
    any border it is numerically parallel to. Intersections that fall inside
    the image enlarged by a 10% margin on every side are kept. The smallest and
    largest of them (ordered by x, then y) are joined with a 2-px segment on
    ``draw``. Nothing is drawn when ``draw`` is None or fewer than two
    intersections survive.
    """
    width, height = image_size
    a, b, c = line
    eps = 1e-9
    candidates = []
    if abs(b) > eps:
        # Left and right borders: solve for y at x = 0 and x = W.
        candidates.append((0, -(a * 0 + c) / b))
        candidates.append((width, -(a * width + c) / b))
    if abs(a) > eps:
        # Top and bottom borders: solve for x at y = 0 and y = H.
        candidates.append((-(b * 0 + c) / a, 0))
        candidates.append((-(b * height + c) / a, height))

    def near_image(pt):
        px, py = pt
        return -width * 0.1 <= px <= width * 1.1 and -height * 0.1 <= py <= height * 1.1

    inside = sorted(filter(near_image, candidates))
    if draw is None or len(inside) < 2:
        return
    (xa, ya), (xb, yb) = inside[0], inside[-1]
    draw.line((xa, ya, xb, yb), fill=color, width=2)

# ------------------------- Gradio app callbacks ---------------------------

# We'll store states in gr.State objects:
# - current_mode: None | 'yellow' | 'red'
# - current_points: list of pending points (len 0 or 1 waiting for second click)
# - yellow_lines: list of (A,B,C)
# - red_lines: list of (A,B,C)
# - yellow_points_pairs: list of ((p1,p2))
# - red_points_pairs: list of ((p1,p2))


def init_states():
    """Return fresh initial values for the drawing state.

    Order: (mode, pending_points, yellow_lines, red_lines, yellow_pairs,
    red_pairs), i.e. no mode selected and five independent empty lists.
    """
    no_mode = None
    empty_lists = tuple([] for _ in range(5))
    return (no_mode,) + empty_lists


def on_mode_change(mode, image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Select which line group ('yellow' or 'red') receives clicks, or None to stop.

    The image and all completed lines are passed through untouched. Only the
    mode changes, and any half-finished (single-click) line is discarded.
    Returns the 7-tuple (image, mode, pending_points, y_lines, r_lines,
    y_pairs, r_pairs).
    """
    pending_cleared = []
    return image, mode, pending_cleared, y_lines, r_lines, y_pairs, r_pairs


def on_image_select(sel: gr.SelectData, image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Handle a click on the image while a drawing mode is active.

    ``sel.index`` is expected to hold the clicked pixel as ``(x, y)``. The first
    click is stored as a pending endpoint. The second click completes a line:
    the line is appended to the active colour group, the overlay is redrawn, and
    the pending list is emptied so the next two clicks start a new, independent
    line.

    The click is ignored (all values returned unchanged) when no image is
    loaded, when no coordinates can be read from ``sel``, or when no drawing
    mode is active. A second click on the same pixel as the first is also
    ignored, because two identical points do not define a line (A = B = 0).

    Returns:
        (image, mode, pending_points, y_lines, r_lines, y_pairs, r_pairs)
    """
    unchanged = (image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs)
    if sel is None or image is None:
        return unchanged

    idx = getattr(sel, "index", None)
    if idx is None:
        # Fallback for event payloads that carry the coordinates elsewhere.
        idx = getattr(sel, "data", None) or getattr(sel, "value", None)
    if not idx:
        return unchanged

    if current_mode not in ("yellow", "red"):
        # Not drawing: don't accumulate stray points that could later pair up.
        return unchanged

    point = (int(idx[0]), int(idx[1]))
    pending = list(current_points) if current_points else []
    pending.append(point)
    if len(pending) < 2:
        # First endpoint of a new line: wait for the second click.
        return image, current_mode, pending, y_lines, r_lines, y_pairs, r_pairs

    p1, p2 = pending[-2], pending[-1]
    if p1 == p2:
        # Degenerate line: keep the first endpoint pending and wait for a distinct click.
        return image, current_mode, [p1], y_lines, r_lines, y_pairs, r_pairs

    line = build_line_from_points(p1, p2)
    if current_mode == "yellow":
        y_lines = list(y_lines or []) + [line]
        y_pairs = list(y_pairs or []) + [(p1, p2)]
    else:
        r_lines = list(r_lines or []) + [line]
        r_pairs = list(r_pairs or []) + [(p1, p2)]

    base_pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    out = draw_overlay(base_pil, y_lines or [], r_lines or [], y_pairs or [], r_pairs or [], vps=None)
    # Reset the pending list so consecutive clicks form independent pairs
    # (p1-p2, p3-p4, ...) rather than a chained polyline (p1-p2, p2-p3, ...).
    return out, current_mode, [], y_lines, r_lines, y_pairs, r_pairs


def _initial_guess(lines_arr, fallback):
    """Return the mean of all pairwise line intersections, or ``fallback`` if there are none."""
    inters = []
    for i in range(len(lines_arr) - 1):
        for j in range(i + 1, len(lines_arr)):
            coeffs = lines_arr[[i, j], :2]
            rhs = -lines_arr[[i, j], 2]
            try:
                inters.append(np.linalg.solve(coeffs, rhs))
            except np.linalg.LinAlgError:
                # Exactly parallel lines have no intersection; skip the pair.
                continue
    if inters:
        return np.mean(inters, axis=0)
    return np.asarray(fallback, dtype=float)


def _estimate_vanishing_point(lines, pairs, fallback):
    """Fit one colour group's vanishing point, or return None if it has fewer than 2 usable lines.

    Lines whose (A, B) part is zero (built from two identical points) are
    dropped together with their endpoint pair, since they have no direction and
    would make the point-to-line distance undefined.
    """
    usable = [(np.asarray(line, dtype=float), pair)
              for line, pair in zip(lines or [], pairs or [])
              if math.hypot(line[0], line[1]) > 0]
    if len(usable) < 2:
        return None

    lines_arr = np.array([line for line, _ in usable])
    p0 = _initial_guess(lines_arr, fallback)

    # Jittered copies of each user line regularise the objective.
    noise = []
    for _, (p1, p2) in usable:
        noise += add_noise_lines_for_line(p1, p2, n=4, sigma=2.0)

    res = minimize(lambda x: total_distances(x, lines_arr, noise), p0, method='Powell')
    return float(res.x[0]), float(res.x[1])


def compute_vanishing_points(image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs):
    """Compute vanishing points for both colour groups, draw them and return the annotated image.

    For each group with at least two usable lines, the starting point is the
    mean of the pairwise line intersections (the image centre if none exist).
    The sum of distances to the lines plus random jittered "noise" lines is
    then minimised with Powell's method. Groups with fewer lines get no
    vanishing point. If no image is loaded, every value is returned unchanged.

    Returns:
        (image, mode, pending_points, y_lines, r_lines, y_pairs, r_pairs)
    """
    if image is None:
        return image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs

    img_pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    center = (img_pil.width / 2, img_pil.height / 2)

    # Yellow is evaluated before red (dict displays evaluate left to right),
    # matching the order in which the random noise lines are drawn.
    vps = {
        "yellow": _estimate_vanishing_point(y_lines, y_pairs, center),
        "red": _estimate_vanishing_point(r_lines, r_pairs, center),
    }

    out = draw_overlay(img_pil, y_lines or [], r_lines or [], y_pairs or [], r_pairs or [], vps=vps)
    return out, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs


def reset_all(image, current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs, original_image):
    """Clear every mode/line state and restore the image without overlays.

    The image stored at upload time is preferred. If none was stored (e.g.
    Reset pressed before any upload), the currently displayed image is used
    instead, although it may still carry overlays. If there is no image at all,
    ``None`` is returned for the image.

    Returns:
        (image, mode, pending_points, y_lines, r_lines, y_pairs, r_pairs), with
        the image as an RGB numpy array because the component uses type="numpy".
    """
    source = original_image if original_image is not None else image
    if source is None:
        return None, None, [], [], [], [], []

    base_pil = source if isinstance(source, Image.Image) else Image.fromarray(source)
    base_np = np.array(base_pil.convert("RGB"))
    return base_np, None, [], [], [], [], []

# ------------------------------ Build Blocks ------------------------------

def build_gradio_interface():
    """Build the Blocks UI and wire every button and image event to its callback.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# 🌕🌖 Shadow Consistency Analysis 🌗🌘🌑")
        gr.Markdown("A utility for verifying geometric consistency of shadows in an image. By projecting vanishing points, it helps determine if all shadows correspond to a single, coherent light source. This method is based on principles of perspective and can be useful for analyzing both traditional manipulations and AI-generated images.")
        with gr.Row():
            img_in = gr.Image(label="Upload image and then click to add points", type="numpy", interactive=True, height=800)
            with gr.Column():
                start_y = gr.Button("Start Yellow Line")
                start_r = gr.Button("Start Red Line")
                none_btn = gr.Button("Stop Drawing")
                compute_btn = gr.Button("Compute vanishing points")
                reset_btn = gr.Button("Reset Figure and Clear Lines")
                gr.Markdown("\nClick the image to add points. Two points => one line. Add at least 2 lines per group to compute a vanishing point.")

        # Per-session state
        current_mode = gr.State(None)    # None | 'yellow' | 'red'
        current_points = gr.State([])    # pending endpoint of the line being drawn
        y_lines = gr.State([])           # yellow line coefficients (A, B, C)
        r_lines = gr.State([])           # red line coefficients (A, B, C)
        y_pairs = gr.State([])           # yellow (p1, p2) endpoint pairs
        r_pairs = gr.State([])           # red (p1, p2) endpoint pairs
        original_image = gr.State(None)  # uploaded image without overlays

        # Every drawing callback takes and returns the image followed by these states.
        line_state = [current_mode, current_points, y_lines, r_lines, y_pairs, r_pairs]
        image_and_state = [img_in] + line_state

        def store_original(img, orig, points, yl, rl, yp, rp):
            """Remember a freshly uploaded image and drop lines drawn on the previous one."""
            if img is None:
                return orig, points, yl, rl, yp, rp
            # Copy so later overlays can never alias the stored original.
            stored = img.copy() if isinstance(img, np.ndarray) else img
            return stored, [], [], [], [], []

        img_in.upload(store_original,
                      inputs=[img_in, original_image, current_points, y_lines, r_lines, y_pairs, r_pairs],
                      outputs=[original_image, current_points, y_lines, r_lines, y_pairs, r_pairs])

        # Mode buttons: the constant mode is passed in through a State input.
        for button, mode in ((start_y, "yellow"), (start_r, "red"), (none_btn, None)):
            button.click(on_mode_change, inputs=[gr.State(mode)] + image_and_state,
                         outputs=image_and_state)

        img_in.select(on_image_select, inputs=image_and_state, outputs=image_and_state)

        compute_btn.click(compute_vanishing_points, inputs=image_and_state, outputs=image_and_state)

        reset_btn.click(reset_all, inputs=image_and_state + [original_image], outputs=image_and_state)

    return demo


if __name__ == '__main__':
    # Build the UI, enable gradio's request queue, then start the server.
    app = build_gradio_interface()
    app.queue()
    app.launch()