Spaces:
Sleeping
Sleeping
| import io | |
| from PIL import Image, ImageFont, ImageDraw | |
| import requests | |
| import matplotlib.pyplot as plt | |
| class PoseClassificationVisualizer(object): | |
| """Keeps track of classifcations for every frame and renders them.""" | |
| def __init__( | |
| self, | |
| class_name, | |
| plot_location_x=0.05, | |
| plot_location_y=0.05, | |
| plot_max_width=0.4, | |
| plot_max_height=0.4, | |
| plot_figsize=(9, 4), | |
| plot_x_max=None, | |
| plot_y_max=None, | |
| counter_location_x=0.85, | |
| counter_location_y=0.05, | |
| counter_font_path="https://github.com/googlefonts/roboto/blob/main/src/hinted/Roboto-Regular.ttf?raw=true", | |
| counter_font_color="red", | |
| counter_font_size=0.15, | |
| ): | |
| self._class_name = class_name | |
| self._plot_location_x = plot_location_x | |
| self._plot_location_y = plot_location_y | |
| self._plot_max_width = plot_max_width | |
| self._plot_max_height = plot_max_height | |
| self._plot_figsize = plot_figsize | |
| self._plot_x_max = plot_x_max | |
| self._plot_y_max = plot_y_max | |
| self._counter_location_x = counter_location_x | |
| self._counter_location_y = counter_location_y | |
| self._counter_font_path = counter_font_path | |
| self._counter_font_color = counter_font_color | |
| self._counter_font_size = counter_font_size | |
| self._counter_font = None | |
| self._pose_classification_history = [] | |
| self._pose_classification_filtered_history = [] | |
| def __call__( | |
| self, | |
| frame, | |
| pose_classification, | |
| pose_classification_filtered, | |
| repetitions_count, | |
| ): | |
| """Renders pose classifcation and counter until given frame.""" | |
| # Extend classification history. | |
| self._pose_classification_history.append(pose_classification) | |
| self._pose_classification_filtered_history.append(pose_classification_filtered) | |
| # Output frame with classification plot and counter. | |
| output_img = Image.fromarray(frame) | |
| output_width = output_img.size[0] | |
| output_height = output_img.size[1] | |
| # Draw the plot. | |
| img = self._plot_classification_history(output_width, output_height) | |
| img.thumbnail( | |
| ( | |
| int(output_width * self._plot_max_width), | |
| int(output_height * self._plot_max_height), | |
| ), | |
| Image.LANCZOS, | |
| ) | |
| output_img.paste( | |
| img, | |
| ( | |
| int(output_width * self._plot_location_x), | |
| int(output_height * self._plot_location_y), | |
| ), | |
| ) | |
| # Draw the count. | |
| output_img_draw = ImageDraw.Draw(output_img) | |
| if self._counter_font is None: | |
| font_size = int(output_height * self._counter_font_size) | |
| font_request = requests.get(self._counter_font_path, allow_redirects=True) | |
| self._counter_font = ImageFont.truetype( | |
| io.BytesIO(font_request.content), size=font_size | |
| ) | |
| output_img_draw.text( | |
| ( | |
| output_width * self._counter_location_x, | |
| output_height * self._counter_location_y, | |
| ), | |
| str(repetitions_count), | |
| font=self._counter_font, | |
| fill=self._counter_font_color, | |
| ) | |
| return output_img | |
| def _plot_classification_history(self, output_width, output_height): | |
| fig = plt.figure(figsize=self._plot_figsize) | |
| for classification_history in [ | |
| self._pose_classification_history, | |
| self._pose_classification_filtered_history, | |
| ]: | |
| y = [] | |
| for classification in classification_history: | |
| if classification is None: | |
| y.append(None) | |
| elif self._class_name in classification: | |
| y.append(classification[self._class_name]) | |
| else: | |
| y.append(0) | |
| plt.plot(y, linewidth=7) | |
| plt.grid(axis="y", alpha=0.75) | |
| plt.xlabel("Frame") | |
| plt.ylabel("Confidence") | |
| plt.title("Classification history for `{}`".format(self._class_name)) | |
| plt.legend(loc="upper right") | |
| if self._plot_y_max is not None: | |
| plt.ylim(top=self._plot_y_max) | |
| if self._plot_x_max is not None: | |
| plt.xlim(right=self._plot_x_max) | |
| # Convert plot to image. | |
| buf = io.BytesIO() | |
| dpi = min( | |
| output_width * self._plot_max_width / float(self._plot_figsize[0]), | |
| output_height * self._plot_max_height / float(self._plot_figsize[1]), | |
| ) | |
| fig.savefig(buf, dpi=dpi) | |
| buf.seek(0) | |
| img = Image.open(buf) | |
| plt.close() | |
| return img | |