# Sta564 / app.py
# Avi3738's picture
# Update app.py
# dac8171 verified
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import rgb2gray, rgba2rgb
from skimage.draw import line
from skimage.filters import gaussian
from skimage.transform import resize
def adjust_brightness_contrast(img, brightness=0.0, contrast=1.0):
    """Return ``img * contrast + brightness`` clipped to the range [0, 1].

    Note that ``contrast`` acts as a gain about 0 (black), not about mid-grey.
    """
    return np.clip(img * contrast + brightness, 0, 1)
def preprocess_image(image, size=(200, 200), brightness=0.0, contrast=1.0):
    """Convert an input image into a normalized darkness map for string art.

    The image is converted to grayscale, resized to ``size``, adjusted with
    ``adjust_brightness_contrast``, inverted so dark pixels become large, and
    finally min-max normalized to [0, 1].

    Args:
        image: 2-D grayscale array, or a 3-D array whose last axis holds
            1 (gray), 3 (RGB) or 4 (RGBA) channels.
        size: (rows, cols) of the working image.
        brightness: additive offset applied after resizing.
        contrast: multiplicative gain applied after resizing.

    Returns:
        Float array of shape ``size`` in [0, 1] where 1 = darkest and
        0 = lightest. A perfectly uniform image yields all zeros.

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        raise ValueError("preprocess_image() requires an image, got None")
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 1:
        # Single-channel image stored with an explicit channel axis.
        image = image[..., 0]
    elif image.ndim == 3 and image.shape[-1] == 4:
        # rgb2gray expects exactly 3 channels; composite RGBA onto a white
        # background so transparent regions count as blank paper.
        image = rgba2rgb(image)
    img_gray = rgb2gray(image) if image.ndim == 3 else image
    img_resized = resize(img_gray, size, anti_aliasing=True)
    img_adj = adjust_brightness_contrast(img_resized, brightness, contrast)
    img_inverted = 1 - img_adj  # black = 1, white = 0
    # The epsilon keeps a uniform image from dividing by zero.
    value_range = img_inverted.max() - img_inverted.min()
    img_inverted = (img_inverted - img_inverted.min()) / (value_range + 1e-8)
    return img_inverted
def generate_pins(num_pins, radius=1.0):
    """Return ``num_pins`` (x, y) points evenly spaced on a circle.

    Pin k sits at angle ``2*pi*k/num_pins``, measured counter-clockwise from
    the +x axis, at distance ``radius`` from the origin. For ``num_pins >= 1``
    the result has shape (num_pins, 2). For ``num_pins == 0`` it is an empty
    1-D array.
    """
    step_angles = np.linspace(0, 2 * np.pi, num_pins, endpoint=False)
    points = []
    for theta in step_angles:
        x = radius * np.cos(theta)
        y = radius * np.sin(theta)
        points.append((x, y))
    return np.array(points)
def draw_line(canvas, pin1, pin2, opacity=0.2):
    """Darken ``canvas`` in place along the segment between two pins.

    Pins are (x, y) positions in [-1, 1]. x is mapped to the column index and
    y to the row index, truncating toward zero. Every pixel on the rasterized
    segment is reduced by ``opacity`` and clipped to [0, 1].

    Returns:
        The same ``canvas`` object, modified in place.
    """
    rows, cols = canvas.shape

    def to_pixel(pin):
        # Map [-1, 1] coordinates onto the pixel grid.
        col = int((pin[0] + 1) / 2 * (cols - 1))
        row = int((pin[1] + 1) / 2 * (rows - 1))
        return row, col

    start_row, start_col = to_pixel(pin1)
    end_row, end_col = to_pixel(pin2)
    rr, cc = line(start_row, start_col, end_row, end_col)
    rr = np.clip(rr, 0, rows - 1)
    cc = np.clip(cc, 0, cols - 1)
    canvas[rr, cc] = np.clip(canvas[rr, cc] - opacity, 0, 1)
    return canvas
def blur_and_downsample(img, sigma=1.5, size=(100, 100)):
    """Gaussian-smooth ``img`` with ``sigma`` and resample it to ``size``.

    The string search compares canvas and target through this function, so
    thin strings are judged by the local tone they produce rather than by
    exact pixels.
    """
    return resize(gaussian(img, sigma=sigma), size, anti_aliasing=True)
def generate_connections(img, pins, num_connections, opacity=0.2):
    """Greedily choose a path of strings between pins that approximates ``img``.

    The walk starts at pin 0. At each step it tries a string to every other
    pin on a copy of the current canvas. Each trial is blurred and downsampled,
    and the pin whose string yields the smallest sum of squared differences
    from the (equally blurred) target is kept. That string is drawn
    permanently and the walk continues from the chosen pin.

    Cost: each step performs ``len(pins) - 1`` full blur/resize passes, so the
    total work grows with ``num_connections * len(pins)``.

    Args:
        img: 2-D array in [0, 1] using the darkness convention produced by
            ``preprocess_image`` (1 = black, 0 = white).
        pins: sequence of (x, y) pin positions in [-1, 1].
        num_connections: maximum number of strings to place.
        opacity: amount each string darkens the canvas pixels it crosses.

    Returns:
        List of ``(from_pin, to_pin)`` index pairs in drawing order. Each
        string starts at the previous string's end pin. The list is empty if
        there are fewer than two pins.
    """
    num_pins = len(pins)
    connections = []
    current_pin = 0
    # The canvas is a brightness image: white paper (1) that strings darken.
    # `img` encodes darkness (1 = black), so convert it back to brightness
    # before comparing. Otherwise strings are pulled into the image's white
    # areas and the result comes out as a negative.
    canvas = np.ones_like(img, dtype=float)
    target_blur = blur_and_downsample(1 - img)
    for _ in range(int(num_connections)):
        best_error = np.inf
        best_pin = None
        for next_pin in range(num_pins):
            if next_pin == current_pin:
                continue
            trial = draw_line(canvas.copy(), pins[current_pin], pins[next_pin], opacity)
            error = np.sum((blur_and_downsample(trial) - target_blur) ** 2)
            if error < best_error:
                best_error = error
                best_pin = next_pin
        if best_pin is None:
            break  # fewer than two pins: no string can be drawn
        connections.append((current_pin, best_pin))
        draw_line(canvas, pins[current_pin], pins[best_pin], opacity)
        current_pin = best_pin
    return connections
def _plot_string_art(pins, connections, opacity):
    """Render the chosen strings and the pins as a matplotlib figure."""
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_facecolor("white")
    ax.axis("off")
    for start, end in connections:
        p1, p2 = pins[start], pins[end]
        ax.plot([p1[0], p2[0]], [p1[1], p2[1]], 'k-', linewidth=0.7, alpha=opacity)
    ax.scatter(pins[:, 0], pins[:, 1], c='red', s=16)
    ax.set_aspect("equal")
    # The string search maps pin y to the pixel row index, which grows
    # downward. Flip the y axis so the plot keeps the input's orientation
    # instead of appearing upside down.
    ax.invert_yaxis()
    plt.tight_layout(pad=0)
    # Remove the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the Figure object itself is still returned.
    plt.close(fig)
    return fig


def _save_pin_sequence(connections):
    """Write the 1-based pin visiting order to a temporary .txt file and return its path."""
    if connections:
        # The connections form one continuous path, so the start pins plus the
        # final end pin give every pin visited. Without the last pin the final
        # string could not be reproduced.
        sequence = [start for start, _ in connections] + [connections[-1][1]]
    else:
        sequence = []
    connections_str = ', '.join(str(pin + 1) for pin in sequence)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8") as temp_file:
        temp_file.write(connections_str)
    return temp_file.name


def draw_string_art(image, num_pins, num_connections, opacity=0.2, brightness=0.0, contrast=1.0):
    """Generate string art for ``image``.

    Returns:
        ``(fig, path)``: the rendered matplotlib figure, and the path of a
        UTF-8 text file listing the visited pins (1-based, comma-separated).
        The file is created with ``delete=False``, so the caller owns cleanup.
    """
    # Gradio documents slider values as floats; pin and string counts must be ints.
    num_pins = int(num_pins)
    num_connections = int(num_connections)
    img = preprocess_image(image, brightness=brightness, contrast=contrast)
    pins = generate_pins(num_pins)
    connections = generate_connections(img, pins, num_connections, opacity)
    fig = _plot_string_art(pins, connections, opacity)
    return fig, _save_pin_sequence(connections)
def interface(image, num_pins, num_connections, opacity, brightness, contrast):
    """Gradio callback: build string art from the uploaded image and slider values.

    Returns the figure and connections-file path produced by ``draw_string_art``,
    in the order of the interface's two output components.

    Raises:
        gr.Error: if no image has been uploaded. Gradio shows this message to
            the user.
    """
    if image is None:
        # Without this guard, None reaches preprocess_image and fails with an
        # unhelpful AttributeError.
        raise gr.Error("Please upload an image first.")
    fig, filename = draw_string_art(image, num_pins, num_connections, opacity, brightness, contrast)
    return fig, filename
# Gradio UI: an image upload plus parameter sliders, producing the rendered
# string art and a downloadable file with the pin sequence.
demo = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(type="numpy", label="Upload Image"),
        gr.Slider(minimum=50, maximum=400, value=100, step=1, label="Number of Pins"),
        gr.Slider(minimum=500, maximum=10000, value=1000, step=1, label="Number of Connections"),
        gr.Slider(minimum=0.05, maximum=0.5, value=0.2, step=0.01, label="String Opacity (Lower = Lighter)"),
        gr.Slider(minimum=-0.5, maximum=0.5, value=0.0, step=0.01, label="Brightness"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.01, label="Contrast"),
    ],
    outputs=[
        gr.Plot(label="String Art"),
        gr.File(label="Connections File"),
    ],
    title="String Art Generator (Wow!Strings-like)",
    description="Upload an image, set the number of pins, connections, string opacity, brightness, and contrast. The generator will create a string art image that closely resembles your input and a downloadable connection file.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()