# File-view metadata: author colonelwatch; commit 9faa360
# ("Add initial app.py and requirements.txt"); size 5.55 kB.
import numpy as np
from skimage import exposure, color, util
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
import gradio as gr
# https://en.wikipedia.org/wiki/Rotation_matrix#General_rotations
def _rotation_matrix(yaw, pitch, roll):
    """Return the 3x3 general rotation matrix ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``.

    Args:
        yaw: Rotation angle about the z axis, in radians.
        pitch: Rotation angle about the y axis, in radians.
        roll: Rotation angle about the x axis, in radians.

    Returns:
        A (3, 3) float array equal to the product of the three elementary
        rotations, applied in the order yaw @ pitch @ roll.
    """
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)
    cos_p, sin_p = np.cos(pitch), np.sin(pitch)
    cos_r, sin_r = np.cos(roll), np.sin(roll)

    about_z = np.array([[cos_y, -sin_y, 0], [sin_y, cos_y, 0], [0, 0, 1]])
    about_y = np.array([[cos_p, 0, sin_p], [0, 1, 0], [-sin_p, 0, cos_p]])
    about_x = np.array([[1, 0, 0], [0, cos_r, -sin_r], [0, sin_r, cos_r]])

    return about_z @ about_y @ about_x
def _calculate_transform():
    """Build the rigid map from CIELAB onto the Solarized base palette.

    The returned function shifts Lab coordinates so the CIELAB center
    (L*=50, a*=0, b*=0) is at the origin. It then rotates them so the
    lightness axis [1, 0, 0] points along the principal component of the
    Solarized base palette. Finally it shifts the result to that palette's
    center.

    Returns:
        A callable taking an array of Lab rows (last axis of length 3) and
        returning the transformed rows with the same shape.
    """
    cielab_center = np.array([50, 0, 0])  # center of CIELAB color space
    # The lightness axis in CIELAB space is spanned by the vector [1, 0, 0].
    solarized_center = np.array([55.5, -6.125, -2.875])  # center of Solarized base palette in CIELAB space
    solarized_axis = np.array([0.951, 0.145, 0.272])  # principal component of Solarized base palette in CIELAB space

    # Pick angles so the rotation sends [1, 0, 0] along solarized_axis. With
    # zero roll, that image is [cos(yaw)cos(pitch), sin(yaw)cos(pitch),
    # -sin(pitch)]. The third component therefore fixes pitch and the second
    # fixes yaw. The first then follows from the matrix being a rotation,
    # which is why solarized_axis[0] is never read.
    pitch_angle = -np.arcsin(solarized_axis[2])
    yaw_angle = np.arcsin(solarized_axis[1] / np.cos(pitch_angle))
    rotation = _rotation_matrix(yaw_angle, pitch_angle, 0)  # roll is a free parameter

    def rotate(lab):
        """Map Lab rows into the Solarized-centered, rotated frame."""
        centered = lab - cielab_center
        return centered @ rotation.T + solarized_center

    return rotate


# Module-level instance applied to every pixel by transform_image.
transform = _calculate_transform()
def preprocess_image(image, light_min, light_max, chroma_attenutation):
    """Compress an RGB image's lightness and chroma in CIELAB space.

    Args:
        image: RGB image accepted by ``skimage.color.rgb2lab``.
        light_min, light_max: Target lightness range (values between 0 and
            100). The L* channel is linearly remapped from [0, 100] onto
            [light_min, light_max].
        chroma_attenutation: Factor between 0 and 1. The a* and b* channels
            are each linearly remapped from [-128, 128] onto
            [-128 * factor, 128 * factor]. A value of 1 leaves chroma as is,
            and 0 collapses it to gray.

    Returns:
        The result converted back with ``skimage.color.lab2rgb``.

    Note:
        ``exposure.rescale_intensity`` clips values to its ``in_range`` before
        remapping.
    """
    lab = color.rgb2lab(image)

    lab[:, :, 0] = exposure.rescale_intensity(
        lab[:, :, 0],
        in_range=(0, 100),
        out_range=(light_min, light_max),
    )

    chroma_limit = 128 * chroma_attenutation
    for channel in (1, 2):  # a* first, then b*
        lab[:, :, channel] = exposure.rescale_intensity(
            lab[:, :, channel],
            in_range=(-128, 128),
            out_range=(-chroma_limit, chroma_limit),
        )

    return color.lab2rgb(lab)
def preprocess_image_parallel(image, light_min, light_max, chroma_attenutation):
    """Apply ``preprocess_image`` to *image* in 1024x1024 tiles via ``util.apply_parallel``.

    Args:
        image: RGB image with channels on axis 2, or None when the Gradio
            input component is empty.
        light_min, light_max, chroma_attenutation: Forwarded unchanged to
            ``preprocess_image``.

    Returns:
        The preprocessed image as a float64 array.

    Raises:
        gr.Error: If *image* is None. Without this guard, None reaches
            ``apply_parallel`` and fails with an opaque traceback instead of
            a readable message in the UI.
    """
    if image is None:
        raise gr.Error('Upload an input image before preprocessing it.')

    preprocess_kwargs = {
        'light_min': light_min,
        'light_max': light_max,
        'chroma_attenutation': chroma_attenutation,
    }
    return util.apply_parallel(
        preprocess_image,
        image,
        (1024, 1024),  # restricted chunk size to prevent OOM-kill
        dtype=np.float64,  # required, according to error message
        extra_keywords=preprocess_kwargs,
        channel_axis=2,  # third axis holds RGB channels
    )
def lightness_hist(image):
    """Plot a histogram of the CIELAB lightness (L*) of *image*.

    Vertical lines mark the bounds labelled "Solarized dark target range"
    and "Solarized light target range".

    The figure is built with ``matplotlib.figure.Figure`` rather than
    ``plt.figure``. Pyplot registers every figure in a global manager until
    it is explicitly closed. This handler runs on every button click in a
    long-lived server, so pyplot figures would accumulate and leak memory.

    Args:
        image: RGB image, or None when the workspace is empty.

    Returns:
        A matplotlib Figure for ``gr.Plot``, or None if there is no image.
    """
    if image is None:
        return None

    fig = Figure(figsize=(12, 12 / 5))  # set aspect ratio of figure to 5:1
    ax = fig.add_subplot()

    lightness = color.rgb2lab(image)[:, :, 0].ravel()
    ax.hist(lightness, bins=64, label=None)

    # Each range gets two lines of one color; only the lower bound carries
    # the legend label, so every range appears once in the legend.
    target_ranges = (
        ('#586e75', 'Solarized dark target range', 8.13974087, 59.4372606),
        ('#93a1a1', 'Solarized light target range', 38.76215165, 93.86995897),
    )
    for line_color, legend_label, low, high in target_ranges:
        ax.axvline(x=low, color=line_color, label=legend_label)
        ax.axvline(x=high, color=line_color, label=None)

    ax.set_xlim(0, 100)
    ax.legend()
    ax.set_xlabel('Lightness')
    ax.set_ylabel('Frequency')

    # Set the aspect ratio of the plot area to 7:1. This differs from the
    # figure's 5:1 to leave room for the labels and legend. set_aspect takes
    # a data-unit ratio, so divide by the current data spans.
    x_left, x_right = ax.get_xlim()
    y_bottom, y_top = ax.get_ylim()
    ax.set_aspect((x_right - x_left) / (y_top - y_bottom) / 7)
    return fig
def transform_image(image):
    """Map every pixel of an RGB image through the module-level ``transform``.

    The image is converted to CIELAB and flattened into an (N, 3) array of
    Lab rows, which is passed through ``transform``. The result is reshaped
    back to the input's shape, converted to RGB, and quantized to uint8 with
    ``util.img_as_ubyte``.

    Args:
        image: RGB image, presumably (H, W, 3) — the reshape back to
            ``image.shape`` requires the Lab array to match the input shape.

    Returns:
        The transformed image as a uint8 array.
    """
    original_shape = image.shape
    lab_rows = color.rgb2lab(image).reshape(-1, 3)
    moved_lab = transform(lab_rows).reshape(original_shape)
    return util.img_as_ubyte(color.lab2rgb(moved_lab))
def transform_image_parallel(image):
    """Apply ``transform_image`` to *image* in 1024x1024 tiles via ``util.apply_parallel``.

    Args:
        image: RGB image with channels on axis 2, or None when the workspace
            has not been filled yet.

    Returns:
        The transformed image as a uint8 array.

    Raises:
        gr.Error: If *image* is None (e.g. Transform clicked before
            Preprocess). The user then gets a readable message instead of an
            opaque failure inside ``apply_parallel``.
    """
    if image is None:
        raise gr.Error('Preprocess an image into the workspace before transforming it.')

    return util.apply_parallel(
        transform_image,
        image,
        (1024, 1024),  # restricted chunk size to prevent OOM-kill
        dtype=np.uint8,  # required, according to error message
        channel_axis=2,  # third axis holds RGB channels
    )
# Two-column UI. The left column holds the input image, the preprocessing
# controls and the action buttons. The right column holds the workspace image
# that both steps write to, plus its lightness histogram.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=320):
            input_image = gr.Image(label='Input')
            light_min_slider = gr.Slider(minimum=0, maximum=100, value=10, label='Lightness minimum')
            light_max_slider = gr.Slider(minimum=0, maximum=100, value=70, label='Lightness maximum')
            chroma_attenutation_slider = gr.Slider(minimum=0, maximum=1, value=0.25, label='Chroma attenuation')
            preprocess_button = gr.Button(value='Preprocess into workspace')
            transform_button = gr.Button(value='Transform workspace')
        with gr.Column(scale=2, min_width=640):
            workspace_image = gr.Image(label='Workspace', interactive=False)
            hist = gr.Plot(label='Lightness histogram')

    # Preprocess writes the workspace, then the histogram is redrawn from it.
    preprocess_button.click(
        preprocess_image_parallel,
        inputs=[input_image, light_min_slider, light_max_slider, chroma_attenutation_slider],
        outputs=[workspace_image],
    ).then(
        lightness_hist,
        inputs=[workspace_image],
        outputs=[hist],
    )
    # Transform rewrites the workspace in place.
    transform_button.click(
        transform_image_parallel,
        inputs=[workspace_image],
        outputs=[workspace_image],
    )

# Launch only when run as a script, so importing this module (e.g. to reuse
# or test the image helpers) does not start a web server.
if __name__ == '__main__':
    demo.launch()