# Commit ec40011 by balthou: "update doc / comments"
import numpy as np
import argparse
from PIL import Image
from interactive_pipe import (
interactive_pipeline, interactive, Curve, SingleCurve
)
from transformers import pipeline
# ----------------------------
# Processing blocks definition
# ----------------------------
@interactive(
    background_color=("green", ["green", "blue", "red"]),
    border_size=(0.05, [0., 0.3]),
)
def generate_background(
    background_color: str = "green",  # dropdown menu (str)
    border_size: float = 0.  # continuous slider (float)
) -> np.ndarray:
    """Create a 256x256 float RGB image: a half-intensity rectangle of the
    chosen primary color surrounded by a black border.

    The slider defaults/ranges come from the @interactive decorator.
    """
    height, width = 256, 256
    canvas = np.zeros((height, width, 3))  # start from a black canvas
    margin = int(border_size * 256)  # border width in pixels
    channel = ["red", "green", "blue"].index(background_color)
    canvas[margin:height - margin, margin:width - margin, channel] = 0.5
    return canvas
@interactive(
    radius=(0.005, [0., 0.01]),
    spread=(1., [0., 2.]),
    geometric_shape=("snooker", ["snooker", "circle", "traffic light"]),
)
def add_circles(
    background: np.ndarray,
    radius: float = 0.,  # continuous slider (float)
    spread: float = 1.,  # continuous slider (float)
    geometric_shape: str = "snooker",  # dropdown menu (str)
) -> np.ndarray:
    """Paint a predefined pattern of colored discs on top of `background`.

    Disc centers live in normalized [-1, 1] coordinates; `spread` scales the
    pattern away from / toward the image center.
    """
    # Normalized coordinate grids matching the image shape.
    grid_x, grid_y = np.meshgrid(
        np.linspace(-1, 1, background.shape[1]),
        np.linspace(-1, 1, background.shape[0]),
    )
    patterns = {
        "snooker": [
            ((0., 0.3), [0.8, 0.8, 0.8]),  # cue ball (white)
            ((0.0, -0.6), [1, 1, 0]),
            ((-0.15, -0.85), [1, 0, 0]),
            ((0.0, -0.85), [0, 1, 0]),
            ((0.15, -0.85), [1, 0, 0]),
            ((-0.075, -0.725), [1, 1, 0]),
            ((0.075, -0.725), [1, 0, 0]),
        ],
        "circle": [
            ((0.7 * np.cos(theta), 0.7 * np.sin(theta)), [1, 1, 0])
            for theta in np.linspace(0, 2 * np.pi, 12)
        ],
        "traffic light": [
            ((0.0, 0.0), [1, 0.8, 0]),
            ((0.0, 0.12), [0, 1, 0]),
            ((0.0, -0.12), [1, 0, 0]),
        ],
    }
    canvas = background.copy()  # never modify the input buffer in place!
    for (center_x, center_y), rgb in patterns[geometric_shape]:
        squared_dist = (grid_x - spread * center_x) ** 2 \
            + (grid_y - spread * center_y) ** 2
        # NOTE: `radius` is compared against the SQUARED distance,
        # matching the slider range [0, 0.01].
        canvas[squared_dist < radius, :] = rgb
    return canvas
@interactive(add_stick=(False, "Add black rectangle"))
def add_details(img: np.ndarray, add_stick: bool = False) -> np.ndarray:
    """Optionally draw a black rectangle ("stick") at the image center."""
    result = img.copy()  # copy to avoid in-place modification of the input
    coord_x, coord_y = np.meshgrid(
        np.linspace(-1, 1, result.shape[1]),
        np.linspace(-1, 1, result.shape[0]),
    )
    if add_stick:
        # Axis-aligned rectangle in normalized coordinates.
        inside = (np.abs(coord_x) < 0.1) * (0.75 * np.abs(coord_y) < 0.2)
        result[inside, :] = 0.
    return result
@interactive(
    noise_level=(0.05, [0., 0.2]),
    seed=(42, [-1, 100])
)
def add_noise(img: np.ndarray, noise_level: float = 0., seed: int = 42):
    """Add white Gaussian noise to `img` and clip the result to [0, 1].

    Args:
        img: float image, values expected in [0, 1].
        noise_level: standard deviation of the Gaussian noise (slider).
        seed: any value >= 0 fixes the RNG state so the noise pattern is
            reproducible; -1 leaves the RNG unseeded (fresh noise each call).
    """
    if seed >= 0:
        # If you do not set the seed, the noise will be different at each call
        # so changing ANY slider value would change the noise pattern —
        # something to avoid in graphical user interfaces.
        # Fixed: the original guard was `seed > 0`, which made seed 0 — a
        # valid NumPy seed reachable via the slider — behave as "unseeded".
        np.random.seed(seed)
    return (img + np.random.normal(0, noise_level, img.shape)).clip(0., 1.)
@interactive(detect=(True, "Enable classification"))
def apply_classifier(
    img: np.ndarray,
    context: dict = {},
    detect: bool = False
) -> None:
    """Classify `img` with an ImageNet ViT model, storing the result in `context`.

    - img: float image, scaled by 255 and cast to uint8 before classification,
      so it is presumably expected in [0, 1] — consistent with add_noise's clip.
    - context: shared state dictionary. NOTE(review): the mutable default `{}`
      looks like an interactive_pipe convention where the framework injects the
      shared context by parameter name — confirm before changing it.
    - detect: checkbox; when off, a 5-entry placeholder result is stored.

    Returns None on purpose: the classification result is published through
    context["result"] because only image/audio buffers may be returned.
    """
    if detect:
        if not context.get("clf", None):
            context["clf"] = pipeline(
                "image-classification",
                model="google/vit-base-patch16-224"
            )
        # Context is used to store the classification pipeline
        # and avoid reloading it on every call.
        result = context["clf"](Image.fromarray((img*255).astype(np.uint8)))
    else:
        result = [{"score": 0., "label": "No classification"}]*5
    # Context is shared between all interactive blocks.
    # We also store the classification result inside this dictionary.
    # We do not return classification results,
    # as these are not image/audio buffers!
    # In display_result, we'll show some curves based
    # on the classification results.
    context["result"] = result
def display_result(context: dict = {}) -> Curve:
    """Turn the classification scores stored in context["result"] into a bar-like
    Curve (one rectangle per label, top-1 label/score in the title).

    Context is shared between all interactive blocks; apply_classifier publishes
    its result there. NOTE(review): the mutable default `{}` mirrors the
    framework's context-injection convention — confirm before changing it.
    """
    result_dict = context.get("result", [])
    curves = [
        SingleCurve(
            # Four (x, y) points trace a rectangle of height `score`.
            x=[idx, idx, idx + 1, idx + 1],
            y=[0, r['score'], r['score'], 0],
            label=r["label"], linestyle="-",
        )
        for idx, r in enumerate(result_dict)]  # `idx` avoids shadowing builtin `id`
    if result_dict:
        title = f"{result_dict[0]['label']} ({result_dict[0]['score']:.2%})"
    else:
        # Fixed: the original indexed result_dict[0] unconditionally and
        # raised IndexError when no result was available yet.
        title = "No classification result"
    return Curve(
        curves,
        ylim=[0, 1],
        title=title
    )
# -------------------
# Pipeline definition
# -------------------
def classification_tutorial_pipeline():
    """Wire the processing blocks into the displayed pipeline.

    Returns a 2x2 display layout: [[background, foreground],
    [noisy_input, result_curve]].

    NOTE(review): interactive_pipe appears to introspect this function to
    build the processing graph and GUI layout, so keep the body as a plain
    sequence of block calls — confirm before restructuring.
    """
    background = generate_background()
    foreground = add_details(background)
    foreground = add_circles(foreground)
    noisy_input = add_noise(foreground)
    apply_classifier(noisy_input)  # publishes context["result"], returns nothing
    result_curve = display_result()
    return [[background, foreground], [noisy_input, result_curve]]
# ----------------------------------------------------------
# Main:
# allows choosing backend through the command line `-b qt`
# ----------------------------------------------------------
if __name__ == "__main__":
    # Backend is selectable from the command line, e.g. `-b qt`.
    BACKEND_OPTIONS = ["gradio", "qt", "mpl"]
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--backend", type=str,
                        choices=BACKEND_OPTIONS, default=BACKEND_OPTIONS[0])
    args = parser.parse_args()
    # Embed this script's own source code in the markdown description.
    md_description = "# 🔍 EXAMPLE Interactive-pipe + machine learning \n"
    # Fixed: use a context manager with explicit encoding instead of a
    # leaked, implicitly-encoded open(__file__).read().
    with open(__file__, "r", encoding="utf-8") as source_file:
        md_description += "```python\n" + source_file.read() + "```\n"
    classification_tutorial_pipeline_interactive = interactive_pipeline(
        gui=args.backend,
        markdown_description=md_description,
    )(classification_tutorial_pipeline)
    classification_tutorial_pipeline_interactive()