File size: 9,447 Bytes
857c2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# https://zhuanlan.zhihu.com/p/418311842
from matplotlib import pyplot as plt
from matplotlib import transforms
from math import sin, cos, atan, pi, sqrt, tan
import numpy as np
from types import MethodType  # 用于更改axes实例方法

plt.rc('font', family='Times New Roman')


class arrow():
    '''
    arrow()向图中添加一个箭头,准确说是没有尾巴的箭头
    参数:
            ax:要绘制的坐标区的句柄
            x:起点和终点的x坐标
            y:起点和终点的y坐标
            size:绘制的箭头大小(箭头的基本大小是不依赖于坐标区的大小的)
            angle:箭头顶端的张角(默认40度)
            d_frac:控制绘制的箭头的尾部中部连接点距顶点的距离是标准状态的d_frac倍,影响箭头样式
            color:箭头填充的颜色
            style:可选样式为full / above / below,分别为标准样式、上三角、下三角
            **kw:用于plt.fill的其他自定义参数
    '''

    def __init__(self, ax, x, y, size, angle=40, d_frac=1, color='#1f77b4', style='full', **kw):
        self.x = x
        self.y = y
        self.size = size
        self.angle = angle
        self.d_frac = d_frac
        self.color = color
        self.style = style
        self.kw = kw
        self.plt_arrow(ax)

    def plt_arrow(self, ax):
        base_d = 0.1 * self.size
        d = base_d * self.d_frac
        half_h = base_d * tan(self.angle / 360 * pi)
        X = [-d, -base_d, 0, -base_d, -d]
        if self.style == 'full':
            Y = [0, half_h, 0, -half_h, 0]
        elif self.style == 'above':
            Y = [0, half_h, 0, 0, 0]
        elif self.style == 'below':
            Y = [0, 0, 0, -half_h, 0]
        else:
            raise ValueError('style的值应为full/above/below')
        trans = (ax.figure.dpi_scale_trans + transforms.ScaledTranslation(self.x[1], self.y[1],
                                                                          ax.transData))  # 设定两个变换,一个使绘制出的箭头大小不依赖于坐标系而是屏幕,一个将箭头移至坐标系中的指定坐标
        theta = self.calc_theta(ax, self.x, self.y)
        coords = np.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]) @ (np.array([X, Y]) + np.array(
            [[0.01], [0]]))  # 将坐标和旋转矩阵相乘,加上的偏移量[[0.01],[0]]目的是解决绘制的箭头处于线条终点时顶端两侧由于线宽导致的像素露出
        self.handle = ax.fill(coords[0, :], coords[1, :], color=self.color, **self.kw, transform=trans)[0]

    @staticmethod
    def calc_theta(ax, x, y):
        '''
        计算两点连线和x轴的夹角
        此函数计算的夹角依赖于屏幕而不是坐标系
        参数:
                ax:要计算绝对角度的坐标系
                x:起始点和终点的x坐标
                y:起始点和终点的y坐标
        '''
        #         bbox = ax.get_window_extent().transformed(ax.figure.dpi_scale_trans.inverted()) #https://stackoverflow.com/questions/19306510/determine-matplotlib-axis-size-in-pixels,获取坐标区的bbox尺寸
        #         ax_width=ax.axis()[1]-ax.axis()[0]
        #         ax_height=ax.axis()[3]-ax.axis()[2]
        #         t_w=(x[1]-x[0])/ax_width*bbox.width #获取线条的‘绝对宽度’
        #         t_h=(y[1]-y[0])/ax_height*bbox.height
        p1 = list(ax.transData.transform([x[0], y[0]]))  # 将坐标转为像素坐标
        p2 = list(ax.transData.transform([x[1], y[1]]))
        t_h = (p2[1] - p1[1])
        t_w = (p2[0] - p1[0])
        if t_h > 0 and t_w == 0:
            theta = pi / 2
        elif t_h < 0 and t_w == 0:
            theta = -pi / 2
        elif t_h == 0 and t_w == 0:
            raise ValueError('请不要设置起始点和终点重合')
        else:
            theta = atan(t_h / t_w)  # 计算图线的‘绝对夹角’
            if t_h < 0 and t_w < 0:
                theta += pi
            elif t_h > 0 and t_w < 0:
                theta += pi
        return theta


class arrowline():
    '''
    arrowline类用于向坐标区中添加带箭头的线,可以是多点的曲线,目前只支持添加仅有一组x,y坐标的曲线,不过应该够用了,更复杂的绘图需求在此之上调用封装就好
    参数:
            style:添加箭头的模式,可选参数:to、to_back、middle、equal_d
                to:在终点添加
                to_back:在终点和起点各添加一个朝向两端的箭头
                middle:在曲线中间点位置添加箭头
                equal_d:在曲线上每隔n个点添加一个箭头
            interval:设置equal_d模式下添加箭头的点间隔
            self.arrows储存了所有箭头的句柄
            self.callback是matplotlib RsizeEvent回调的编号,该回调用于命令行运行脚本后得到的图形界面中实时查看时,在缩放窗口时箭头方位能正确变化
    '''

    def __init__(self, ax, x, y, cosys='rect', arrow_size=1, style='to', arrow_angle=40, d_frac=1, color='#1f77b4',
                 arrow_style='full', interval=50, **kw):
        self.arrows = []
        arrowline.set_ax(ax)  # 对当前ax进行初始化设置,增加一些属性,对axes实例的方法进行一些增添修改
        if cosys == 'rect' or cosys == 'polar':
            self.handle = ax.plot(x, y, color=color, **kw)[0]
        elif cosys == 'loglog':
            self.handle = ax.loglog(x, y, color=color, **kw)[0]

        x, y, l = list(x), list(y), len(x)
        kw = {'angle': arrow_angle, 'd_frac': d_frac, 'color': color, 'style': arrow_style}
        if style == 'to':
            l -= 1
            self.arrows.append(arrow(ax, [x[l - 1], x[l]], [y[l - 1], y[l]], arrow_size, **kw))
        elif style == 'to_back':
            l -= 1
            self.arrows.append(arrow(ax, [x[l - 1], x[l]], [y[l - 1], y[l]], arrow_size, **kw))
            self.arrows.append(arrow(ax, [x[1], x[0]], [y[1], y[0]], arrow_size, **kw))
        elif style == 'middle':
            if l % 2 == 0:
                l = int(l / 2)
                self.arrows.append(
                    arrow(ax, [x[l - 1], (x[l - 1] + x[l]) / 2], [y[l - 1], (y[l - 1] + y[l]) / 2], arrow_size, **kw))
            else:
                l = int(l / 2)
                self.arrows.append(arrow(ax, [x[l - 1], x[l]], [y[l - 1], y[l]], arrow_size, **kw))
        elif style == 'equal_d':
            if type(interval) != type(1):
                raise ValueError('interval应为正整数')
            if interval <= 0:
                raise ValueError('请不要设置interval为小于等于零的数')
            i = interval
            while i < l:
                self.arrows.append(arrow(ax, [x[i - 1], x[i]], [y[i - 1], y[i]], arrow_size, **kw))
                i += interval
        else:
            raise ValueError('请检查style参数是否正确')
        self.callback = ax
        ax.arrowlines.append(self)

    def remove(self):  # 彻底移除arrowline
        self.handle.axes.arrowlines.remove(self)
        self.handle.remove()
        for i in self.arrows:
            i.handle.remove()
            del i
        del self

    @property
    def callback(self):
        return self._callback

    @callback.setter
    def callback(self, ax):
        def resized(event):
            for i in self.arrows:
                i.handle.remove()
                i.plt_arrow(ax)

        self._callback = ax.figure.canvas.mpl_connect('resize_event', resized)

    @staticmethod
    def set_ax(ax):
        try:
            ax.prev_xlim
        except AttributeError:  # 如果不存在prev_xlim属性,则进行初始化,避免重复初始化
            ax.prev_xlim = ax.get_xlim()
            ax.prev_ylim = ax.get_ylim()
            ax.arrowlines = []

            def check(self):
                xlim = self.get_xlim()
                ylim = self.get_ylim()
                if self.prev_xlim != xlim or self.prev_ylim != ylim:
                    self.prev_xlim = xlim
                    self.prev_ylim = ylim
                    for obj in self.arrowlines:
                        for i in obj.arrows:
                            i.handle.remove()
                            i.plt_arrow(ax)

            ax.check_xylim = MethodType(check, ax)
            ax.clear1 = ax.clear  # 防止循环调用

            def new_clear(self):
                self.arrowlines = []
                return self.clear1()

            ax.clear = MethodType(new_clear, ax)  # 更改ax原有的clear()方法以实现清空画布时同步清空arrowlines列表,防止清空画布后遗留的arrowlines会重复绘制
            ax._request_autoscale_view1 = ax._request_autoscale_view

            def new_request_autoscale_view(self, tight=None, scalex=True, scaley=True):
                args = self._request_autoscale_view1(tight,
                                                     scalex, scaley
                                                     )
                self.check_xylim()
                return args

            ax._request_autoscale_view = MethodType(new_request_autoscale_view,
                                                    ax)  # 更改原有的_request_autoscale_view方法以使得每次坐标区的视图状态改变时箭头方向都能及时更新