YOGAI / PoseClassification /visualize.py
1mpreccable's picture
Upload 35 files
0ccc9b6 verified
import io
from PIL import Image, ImageFont, ImageDraw
import requests
import matplotlib.pyplot as plt
class PoseClassificationVisualizer(object):
"""Keeps track of classifcations for every frame and renders them."""
def __init__(
self,
class_name,
plot_location_x=0.05,
plot_location_y=0.05,
plot_max_width=0.4,
plot_max_height=0.4,
plot_figsize=(9, 4),
plot_x_max=None,
plot_y_max=None,
counter_location_x=0.85,
counter_location_y=0.05,
counter_font_path="https://github.com/googlefonts/roboto/blob/main/src/hinted/Roboto-Regular.ttf?raw=true",
counter_font_color="red",
counter_font_size=0.15,
):
self._class_name = class_name
self._plot_location_x = plot_location_x
self._plot_location_y = plot_location_y
self._plot_max_width = plot_max_width
self._plot_max_height = plot_max_height
self._plot_figsize = plot_figsize
self._plot_x_max = plot_x_max
self._plot_y_max = plot_y_max
self._counter_location_x = counter_location_x
self._counter_location_y = counter_location_y
self._counter_font_path = counter_font_path
self._counter_font_color = counter_font_color
self._counter_font_size = counter_font_size
self._counter_font = None
self._pose_classification_history = []
self._pose_classification_filtered_history = []
def __call__(
self,
frame,
pose_classification,
pose_classification_filtered,
repetitions_count,
):
"""Renders pose classifcation and counter until given frame."""
# Extend classification history.
self._pose_classification_history.append(pose_classification)
self._pose_classification_filtered_history.append(pose_classification_filtered)
# Output frame with classification plot and counter.
output_img = Image.fromarray(frame)
output_width = output_img.size[0]
output_height = output_img.size[1]
# Draw the plot.
img = self._plot_classification_history(output_width, output_height)
img.thumbnail(
(
int(output_width * self._plot_max_width),
int(output_height * self._plot_max_height),
),
Image.LANCZOS,
)
output_img.paste(
img,
(
int(output_width * self._plot_location_x),
int(output_height * self._plot_location_y),
),
)
# Draw the count.
output_img_draw = ImageDraw.Draw(output_img)
if self._counter_font is None:
font_size = int(output_height * self._counter_font_size)
font_request = requests.get(self._counter_font_path, allow_redirects=True)
self._counter_font = ImageFont.truetype(
io.BytesIO(font_request.content), size=font_size
)
output_img_draw.text(
(
output_width * self._counter_location_x,
output_height * self._counter_location_y,
),
str(repetitions_count),
font=self._counter_font,
fill=self._counter_font_color,
)
return output_img
def _plot_classification_history(self, output_width, output_height):
fig = plt.figure(figsize=self._plot_figsize)
for classification_history in [
self._pose_classification_history,
self._pose_classification_filtered_history,
]:
y = []
for classification in classification_history:
if classification is None:
y.append(None)
elif self._class_name in classification:
y.append(classification[self._class_name])
else:
y.append(0)
plt.plot(y, linewidth=7)
plt.grid(axis="y", alpha=0.75)
plt.xlabel("Frame")
plt.ylabel("Confidence")
plt.title("Classification history for `{}`".format(self._class_name))
plt.legend(loc="upper right")
if self._plot_y_max is not None:
plt.ylim(top=self._plot_y_max)
if self._plot_x_max is not None:
plt.xlim(right=self._plot_x_max)
# Convert plot to image.
buf = io.BytesIO()
dpi = min(
output_width * self._plot_max_width / float(self._plot_figsize[0]),
output_height * self._plot_max_height / float(self._plot_figsize[1]),
)
fig.savefig(buf, dpi=dpi)
buf.seek(0)
img = Image.open(buf)
plt.close()
return img