Spaces:
Running on Zero
Running on Zero
| # coding: utf-8 | |
| import cv2; cv2.setNumThreads(0); cv2.ocl.setUseOpenCL(False) | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D | |
| from matplotlib.patches import FancyArrowPatch | |
| from mpl_toolkits.mplot3d import proj3d | |
def viz_lmk(img_, vps, **kwargs):
    """Draw each point in ``vps`` as a circle on a copy of ``img_``.

    Args:
        img_: Image array accepted by ``cv2.circle``. It is copied, never modified.
        vps: Iterable of points. Only ``pt[0]`` (x) and ``pt[1]`` (y) are used,
            truncated to integer pixel coordinates.
        **kwargs: Optional drawing settings:
            ``radius`` (default 1), ``thickness`` (default 1),
            ``lineType`` (default ``cv2.LINE_8``; ``cv2.LINE_AA`` anti-aliases),
            ``color`` (default ``(0, 255, 0)``).

    Returns:
        The annotated copy of ``img_``.
    """
    line_type = kwargs.get("lineType", cv2.LINE_8)  # cv2.LINE_AA for smoother circles
    radius = kwargs.get("radius", 1)
    thickness = kwargs.get("thickness", 1)
    # Was hard-coded; now overridable while keeping the old default.
    color = kwargs.get("color", (0, 255, 0))

    img_for_viz = img_.copy()
    for pt in vps:
        cv2.circle(
            img_for_viz,
            (int(pt[0]), int(pt[1])),
            radius=radius,
            color=color,
            thickness=thickness,
            lineType=line_type,
        )
    return img_for_viz
def plot_3d_scatter(data, xlabel='X Label', ylabel='Y Label', zlabel='Z Label', filename=None):
    """Draw a 3D scatter plot, coloring each point by its third coordinate.

    Args:
        data: Array-like of shape (n, 3), where n is the number of samples.
            Columns 0, 1 and 2 are plotted on the X, Y and Z axes.
        xlabel: Text for the X axis label. Defaults to 'X Label'.
        ylabel: Text for the Y axis label. Defaults to 'Y Label'.
        zlabel: Text for the Z axis label. Defaults to 'Z Label'.
        filename: If truthy, the figure is saved to this path. Defaults to None
            (not saved).

    Returns:
        The created matplotlib Figure. It is left open (as before). Callers
        producing many plots should ``plt.close(fig)`` when finished.
    """
    data = np.asarray(data)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Color by the third feature; 'Set1' is a qualitative, high-contrast colormap.
    ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=data[:, 2], cmap='Set1')

    # These parameters were previously accepted but never applied to the axes.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)

    if filename:
        fig.savefig(filename, bbox_inches='tight')
    return fig
class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose two endpoints are given in 3D data coordinates.

    ``xs``, ``ys`` and ``zs`` each hold the (start, end) values along one axis.
    The endpoints are re-projected to 2D display positions whenever the
    containing 3D axes projects or draws its patches.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2D positions; the real ones are set during projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _projection_matrix(self, renderer=None):
        """Return the current 3D-to-2D projection matrix."""
        # Newer Matplotlib keeps the matrix on the Axes3D (``renderer.M`` was
        # deprecated in 3.4 and later removed); fall back for old versions.
        M = getattr(self.axes, "M", None)
        if M is None:
            M = renderer.M
        return M

    def _project(self, M):
        """Project the 3D endpoints with ``M``, update the arrow, return projected zs."""
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Project the arrow and return its depth for z-sorting.

        Matplotlib >= 3.5 calls this on every patch in a 3D axes (without
        arguments); without it, drawing the arrow raises AttributeError.
        """
        zs = self._project(self._projection_matrix(renderer))
        return np.min(zs)

    def draw(self, renderer):
        self._project(self._projection_matrix(renderer))
        super().draw(renderer)
def plot_vectors(points, vector_diffs, filename=None):
    """Plot 3D start points with a unit-length arrow along each displacement.

    Args:
        points: Array-like of shape (n, 3) holding the arrow start points
            (drawn in red).
        vector_diffs: Array-like with n rows. Row i is the displacement whose
            direction is drawn (in green) from ``points[i]``. Arrows are
            normalized to length 1, so only direction, not magnitude, is shown.
        filename: If truthy, the figure is saved to this path.

    Returns:
        The created matplotlib Figure. It is left open; close it with
        ``plt.close(fig)`` when no longer needed.

    Raises:
        AssertionError: if ``points`` and ``vector_diffs`` differ in length.
    """
    points = np.asarray(points)
    vector_diffs = np.asarray(vector_diffs)
    assert points.shape[0] == vector_diffs.shape[0], "Points and vector diffs must have the same number of elements."

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    if points.shape[0] > 0:
        # One vectorized call each, instead of a scatter and a quiver artist per
        # point. Quiver's normalize=True still normalizes every arrow individually.
        ax.scatter(points[:, 0], points[:, 1], points[:, 2], color='r', label='Start Point')
        ax.quiver(points[:, 0], points[:, 1], points[:, 2],
                  vector_diffs[:, 0], vector_diffs[:, 1], vector_diffs[:, 2],
                  length=1, normalize=True, color="g")

    if filename:
        fig.savefig(filename, bbox_inches='tight')
    return fig
def plot_vector_pairs(points1, points2, filename=None):
    """Plot two matched 3D point sets with an arrow from each point to its partner.

    Args:
        points1: Array-like of shape (n, 3). These are the arrow start points,
            drawn in red as 'Point Set 1'.
        points2: Array-like with the same shape as ``points1``. These are the
            arrow end points, drawn in blue as 'Point Set 2'.
        filename: If truthy, the figure is saved to this path.

    Returns:
        The created matplotlib Figure. It is left open; close it with
        ``plt.close(fig)`` when no longer needed.

    Raises:
        AssertionError: if the two point sets have different shapes.
    """
    points1 = np.asarray(points1)
    points2 = np.asarray(points2)
    assert points1.shape == points2.shape, "两点序列的形状必须一致"

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(points1[:, 0], points1[:, 1], points1[:, 2], color='r', label='Point Set 1')
    ax.scatter(points2[:, 0], points2[:, 1], points2[:, 2], color='b', label='Point Set 2')

    if points1.shape[0] > 0:
        # A single quiver call draws every connecting arrow at its true length
        # (normalize=False), replacing one quiver artist per pair.
        deltas = points2 - points1
        ax.quiver(points1[:, 0], points1[:, 1], points1[:, 2],
                  deltas[:, 0], deltas[:, 1], deltas[:, 2],
                  length=1, normalize=False, color="g")

    if filename:
        fig.savefig(filename, bbox_inches='tight')
    return fig