ryo2 commited on
Commit
61a6540
·
verified ·
1 Parent(s): 590edae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -255
app.py CHANGED
@@ -1,278 +1,278 @@
1
- import gradio as gr
2
- import polars as pl
3
- import os
4
- import shutil
5
- import numpy as np
6
- from PyQt6.QtWidgets import QApplication
7
- from PyQt6 import QtCore
8
- import pyqtgraph as pg
9
- from pyqtgraph.exporters import ImageExporter
10
 
11
- class DataProcessor:
12
- def __init__(self, bodypart_names, x_max, y_max):
13
- # 余分な空白を除去してリスト化
14
- self.bodypart_names = [name.strip() for name in bodypart_names.split(',')]
15
- self.x_max = x_max
16
- self.y_max = y_max
17
- self.output_folder = 'output_plots'
18
- if not os.path.exists(self.output_folder):
19
- os.makedirs(self.output_folder)
20
 
21
- def process_csv(self, file_path):
22
- # CSVをpolarsで読み込み(ヘッダーはなし)
23
- df_raw = pl.read_csv(file_path, has_header=False)
24
- # 2行分のヘッダー(インデックス1,2)を取得し、最初の列は除外する
25
- header1 = df_raw.row(1)[1:]
26
- header2 = df_raw.row(2)[1:]
27
- new_columns = [f"{h1}|{h2}" for h1, h2 in zip(header1, header2)]
28
- # データ部分はインデックス3以降(0-indexed)とし、先頭列を削除
29
- df_data = df_raw.slice(3)
30
- first_col = df_data.columns[0]
31
- df_data = df_data.drop(first_col)
32
- df_data.columns = new_columns
33
 
34
- # likelihood列のみ抽出
35
- df_likelihood = self.extract_likelihood(df_data)
36
- # likelihood列を除去したデータ
37
- df_no_likelihood = self.remove_likelihood(df_data)
38
- # 付属肢名の置換(左側の名前を mapping で変更)
39
- df_renamed = self.rename_bodyparts(df_no_likelihood)
40
- return df_renamed, df_likelihood
41
 
42
- def remove_likelihood(self, df):
43
- # 列名が "bodypart|likelihood" となっている列を除外
44
- new_cols = [col for col in df.columns if col.split("|")[1] != "likelihood"]
45
- return df.select(new_cols)
46
 
47
- def rename_bodyparts(self, df):
48
- cols = df.columns
49
- current_names = []
50
- for col in cols:
51
- bp = col.split("|")[0]
52
- if bp not in current_names:
53
- current_names.append(bp)
54
- if len(self.bodypart_names) != len(current_names):
55
- raise ValueError("The length of bodypart_names must be equal to the number of bodyparts.")
56
- mapping = dict(zip(current_names, self.bodypart_names))
57
- new_cols = {col: f"{mapping[col.split('|')[0]]}|{col.split('|')[1]}" for col in cols}
58
- return df.rename(new_cols)
59
 
60
- def extract_likelihood(self, df):
61
- # likelihood列のみを抽出
62
- likelihood_cols = [col for col in df.columns if col.split("|")[1] == "likelihood"]
63
- df_likelihood = df.select(likelihood_cols)
64
- current_names = []
65
- for col in likelihood_cols:
66
- bp = col.split("|")[0]
67
- if bp not in current_names:
68
- current_names.append(bp)
69
- if len(self.bodypart_names) != len(current_names):
70
- raise ValueError("The length of bodypart_names must be equal to the number of bodyparts.")
71
- mapping = dict(zip(current_names, self.bodypart_names))
72
- new_cols = {col: f"{mapping[col.split('|')[0]]}|{col.split('|')[1]}" for col in likelihood_cols}
73
- return df_likelihood.rename(new_cols)
74
 
75
- def get_bodyparts(self, df):
76
- bodyparts = []
77
- for col in df.columns:
78
- bp = col.split("|")[0]
79
- if bp not in bodyparts:
80
- bodyparts.append(bp)
81
- return bodyparts
82
 
83
- def plot_scatter(self, df):
84
- image_paths = []
85
- bodyparts = self.get_bodyparts(df)
86
- app = QApplication.instance()
87
- if app is None:
88
- app = QApplication([])
89
 
90
- # 個別の散布図を作成
91
- for i, bodypart in enumerate(bodyparts):
92
- try:
93
- x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
94
- y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
95
- except Exception as e:
96
- continue
97
- pw = pg.PlotWidget(title=f'トラッキングの座標({bodypart})')
98
- pw.setLabel('bottom', 'X Coordinate(pixel)')
99
- pw.setLabel('left', 'Y Coordinate(pixel)')
100
- pw.setXRange(0, self.x_max)
101
- pw.setYRange(0, self.y_max)
102
- pw.invertY(True)
103
- color = pg.intColor(i, len(bodyparts))
104
- # 散布図アイテムの追加
105
- scatter = pg.ScatterPlotItem(x=x, y=y, pen=pg.mkPen(color=color), symbol='o', brush=color)
106
- pw.addItem(scatter)
107
- # 始点を黒丸でハイライトし、"Start"テキストを追加
108
- if len(x) > 0:
109
- scatter_start = pg.ScatterPlotItem(x=[x[0]], y=[y[0]], pen=pg.mkPen(color='k'), symbol='o', size=10, brush='k')
110
- pw.addItem(scatter_start)
111
- text = pg.TextItem("Start", color='k')
112
- text.setPos(x[0], y[0])
113
- pw.addItem(text)
114
- # PNGにエクスポート
115
- exporter = ImageExporter(pw.plotItem)
116
- filename = os.path.join(self.output_folder, f"{bodypart}.png")
117
- exporter.export(filename)
118
- image_paths.append(filename)
119
 
120
- # 全付属肢の散布図を作成
121
- pw_all = pg.PlotWidget(title='トラッキングの座標(全付属肢)')
122
- pw_all.setLabel('bottom', 'X Coordinate(pixel)')
123
- pw_all.setLabel('left', 'Y Coordinate(pixel)')
124
- pw_all.setXRange(0, self.x_max)
125
- pw_all.setYRange(0, self.y_max)
126
- pw_all.invertY(True)
127
- for i, bodypart in enumerate(bodyparts):
128
- try:
129
- x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
130
- y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
131
- except Exception as e:
132
- continue
133
- color = pg.intColor(i, len(bodyparts))
134
- scatter = pg.ScatterPlotItem(x=x, y=y, pen=pg.mkPen(color=color), symbol='o', brush=color)
135
- pw_all.addItem(scatter)
136
- exporter_all = ImageExporter(pw_all.plotItem)
137
- filename_all = os.path.join(self.output_folder, "all_plot.png")
138
- exporter_all.export(filename_all)
139
- image_paths.append(filename_all)
140
- return image_paths
141
 
142
- def plot_trajectories(self, df):
143
- image_paths = []
144
- bodyparts = self.get_bodyparts(df)
145
- app = QApplication.instance()
146
- if app is None:
147
- app = QApplication([])
148
 
149
- # 個別の軌跡図を作成
150
- for i, bodypart in enumerate(bodyparts):
151
- try:
152
- x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
153
- y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
154
- except Exception as e:
155
- continue
156
- pw = pg.PlotWidget(title=f'トラッキングの座標({bodypart})')
157
- pw.setLabel('bottom', 'Frames')
158
- pw.setLabel('left', 'Coordinate(pixel)')
159
- pen_x = pg.mkPen(color=pg.intColor(i, len(bodyparts)), style=QtCore.Qt.PenStyle.DashLine)
160
- pen_y = pg.mkPen(color=pg.intColor(i, len(bodyparts)))
161
- pw.plot(x, pen=pen_x, name=f"{bodypart}(x座標)")
162
- pw.plot(y, pen=pen_y, name=f"{bodypart}(y座標)")
163
- exporter = ImageExporter(pw.plotItem)
164
- filename = os.path.join(self.output_folder, f"{bodypart}_trajectories.png")
165
- exporter.export(filename)
166
- image_paths.append(filename)
167
 
168
- # 全付属肢の軌跡図を作成
169
- pw_all = pg.PlotWidget(title='トラッキングの座標(全付属肢)')
170
- pw_all.setLabel('bottom', 'Frames')
171
- pw_all.setLabel('left', 'Coordinate(pixel)')
172
- for i, bodypart in enumerate(bodyparts):
173
- try:
174
- x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
175
- y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
176
- except Exception as e:
177
- continue
178
- pen_x = pg.mkPen(color=pg.intColor(i, len(bodyparts)), style=QtCore.Qt.PenStyle.DashLine)
179
- pen_y = pg.mkPen(color=pg.intColor(i, len(bodyparts)))
180
- pw_all.plot(x, pen=pen_x, name=f"{bodypart}(x座標)")
181
- pw_all.plot(y, pen=pen_y, name=f"{bodypart}(y座標)")
182
- exporter_all = ImageExporter(pw_all.plotItem)
183
- filename_all = os.path.join(self.output_folder, "all_trajectories.png")
184
- exporter_all.export(filename_all)
185
- image_paths.append(filename_all)
186
- return image_paths
187
 
188
- def plot_likelihood(self, df_likelihood):
189
- image_paths = []
190
- bodyparts = self.get_bodyparts(df_likelihood)
191
- app = QApplication.instance()
192
- if app is None:
193
- app = QApplication([])
194
 
195
- # 付属肢ごとの尤度グラフを作成
196
- for i, bodypart in enumerate(bodyparts):
197
- try:
198
- likelihood = np.array(df_likelihood[f"{bodypart}|likelihood"].to_list(), dtype=float)
199
- except Exception as e:
200
- continue
201
- pw = pg.PlotWidget(title=f'フレーム別の尤度 ({bodypart})')
202
- pw.setLabel('bottom', 'Frames')
203
- pw.setLabel('left', '尤度')
204
- pw.setYRange(0, 1.0)
205
- color = pg.intColor(i, len(bodyparts))
206
- pw.plot(likelihood, pen=pg.mkPen(color=color), name=bodypart)
207
- exporter = ImageExporter(pw.plotItem)
208
- filename = os.path.join(self.output_folder, f"{bodypart}_likelihood.png")
209
- exporter.export(filename)
210
- image_paths.append(filename)
211
 
212
- # 全付属肢の尤度グラフを作成
213
- pw_all = pg.PlotWidget(title='フレーム別の尤度 (全付属肢)')
214
- pw_all.setLabel('bottom', 'Frames')
215
- pw_all.setLabel('left', '尤度')
216
- pw_all.setYRange(0, 1.0)
217
- for i, bodypart in enumerate(bodyparts):
218
- try:
219
- likelihood = np.array(df_likelihood[f"{bodypart}|likelihood"].to_list(), dtype=float)
220
- except Exception as e:
221
- continue
222
- color = pg.intColor(i, len(bodyparts))
223
- pw_all.plot(likelihood, pen=pg.mkPen(color=color), name=bodypart)
224
- exporter_all = ImageExporter(pw_all.plotItem)
225
- filename_all = os.path.join(self.output_folder, "likelihood_plot.png")
226
- exporter_all.export(filename_all)
227
- image_paths.append(filename_all)
228
- return image_paths
229
 
230
- class GradioInterface:
231
- def __init__(self):
232
- self.interface = gr.Interface(
233
- fn=self.process_and_plot,
234
- inputs=[
235
- gr.File(label="CSVファイルをドラッグ&ドロップ"),
236
- gr.Textbox(
237
- label="付属肢の名前(カンマ区切り)",
238
- value="指節1, 指節2, 指節3, 指節4, 指節5, 指節6, 指節7, 指節8, 指節9, 指節10, 指節11, 指節12, 指節13, 指節14, 触角(左), 触角(右), 頭部, 腹尾節"
239
- ),
240
- gr.Number(label="X軸の最大値", value=1920),
241
- gr.Number(label="Y軸の最大値", value=1080),
242
- gr.CheckboxGroup(
243
- label="プロットするグラフを選択",
244
- choices=["散布図", "軌跡図", "尤度グラフ"],
245
- value=["散布図", "軌跡図", "尤度グラフ"],
246
- type="value"
247
- )
248
- ],
249
- outputs=[
250
- gr.Gallery(label="散布図"),
251
- gr.File(label="ZIPダウンロード")
252
- ],
253
- title="DeepLabCutグラフ出力ツール",
254
- description="CSVファイルからグラフを作成します。"
255
- )
256
 
257
- def process_and_plot(self, file, bodypart_names, x_max, y_max, graph_choices):
258
- processor = DataProcessor(bodypart_names, x_max, y_max)
259
- df, df_likelihood = processor.process_csv(file.name)
260
 
261
- all_image_paths = []
262
- if "散布図" in graph_choices:
263
- all_image_paths += processor.plot_scatter(df)
264
- if "軌跡図" in graph_choices:
265
- all_image_paths += processor.plot_trajectories(df)
266
- if "尤度グラフ" in graph_choices:
267
- all_image_paths += processor.plot_likelihood(df_likelihood)
268
 
269
- shutil.make_archive(processor.output_folder, 'zip', processor.output_folder)
270
- return all_image_paths, processor.output_folder + '.zip'
271
 
272
- def launch(self):
273
- self.interface.launch()
274
 
275
 
276
- if __name__ == "__main__":
277
- gradio_app = GradioInterface()
278
- gradio_app.launch()
 
1
+ import gradio as gr
2
+ import polars as pl
3
+ import os
4
+ import shutil
5
+ import numpy as np
6
+ from PyQt6.QtWidgets import QApplication
7
+ from PyQt6 import QtCore
8
+ import pyqtgraph as pg
9
+ from pyqtgraph.exporters import ImageExporter
10
 
11
+ class DataProcessor:
12
+ def __init__(self, bodypart_names, x_max, y_max):
13
+ # 余分な空白を除去してリスト化
14
+ self.bodypart_names = [name.strip() for name in bodypart_names.split(',')]
15
+ self.x_max = x_max
16
+ self.y_max = y_max
17
+ self.output_folder = 'output_plots'
18
+ if not os.path.exists(self.output_folder):
19
+ os.makedirs(self.output_folder)
20
 
21
+ def process_csv(self, file_path):
22
+ # CSVをpolarsで読み込み(ヘッダーはなし)
23
+ df_raw = pl.read_csv(file_path, has_header=False)
24
+ # 2行分のヘッダー(インデックス1,2)を取得し、最初の列は除外する
25
+ header1 = df_raw.row(1)[1:]
26
+ header2 = df_raw.row(2)[1:]
27
+ new_columns = [f"{h1}|{h2}" for h1, h2 in zip(header1, header2)]
28
+ # データ部分はインデックス3以降(0-indexed)とし、先頭列を削除
29
+ df_data = df_raw.slice(3)
30
+ first_col = df_data.columns[0]
31
+ df_data = df_data.drop(first_col)
32
+ df_data.columns = new_columns
33
 
34
+ # likelihood列のみ抽出
35
+ df_likelihood = self.extract_likelihood(df_data)
36
+ # likelihood列を除去したデータ
37
+ df_no_likelihood = self.remove_likelihood(df_data)
38
+ # 付属肢名の置換(左側の名前を mapping で変更)
39
+ df_renamed = self.rename_bodyparts(df_no_likelihood)
40
+ return df_renamed, df_likelihood
41
 
42
+ def remove_likelihood(self, df):
43
+ # 列名が "bodypart|likelihood" となっている列を除外
44
+ new_cols = [col for col in df.columns if col.split("|")[1] != "likelihood"]
45
+ return df.select(new_cols)
46
 
47
+ def rename_bodyparts(self, df):
48
+ cols = df.columns
49
+ current_names = []
50
+ for col in cols:
51
+ bp = col.split("|")[0]
52
+ if bp not in current_names:
53
+ current_names.append(bp)
54
+ if len(self.bodypart_names) != len(current_names):
55
+ raise ValueError("The length of bodypart_names must be equal to the number of bodyparts.")
56
+ mapping = dict(zip(current_names, self.bodypart_names))
57
+ new_cols = {col: f"{mapping[col.split('|')[0]]}|{col.split('|')[1]}" for col in cols}
58
+ return df.rename(new_cols)
59
 
60
+ def extract_likelihood(self, df):
61
+ # likelihood列のみを抽出
62
+ likelihood_cols = [col for col in df.columns if col.split("|")[1] == "likelihood"]
63
+ df_likelihood = df.select(likelihood_cols)
64
+ current_names = []
65
+ for col in likelihood_cols:
66
+ bp = col.split("|")[0]
67
+ if bp not in current_names:
68
+ current_names.append(bp)
69
+ if len(self.bodypart_names) != len(current_names):
70
+ raise ValueError("The length of bodypart_names must be equal to the number of bodyparts.")
71
+ mapping = dict(zip(current_names, self.bodypart_names))
72
+ new_cols = {col: f"{mapping[col.split('|')[0]]}|{col.split('|')[1]}" for col in likelihood_cols}
73
+ return df_likelihood.rename(new_cols)
74
 
75
+ def get_bodyparts(self, df):
76
+ bodyparts = []
77
+ for col in df.columns:
78
+ bp = col.split("|")[0]
79
+ if bp not in bodyparts:
80
+ bodyparts.append(bp)
81
+ return bodyparts
82
 
83
+ def plot_scatter(self, df):
84
+ image_paths = []
85
+ bodyparts = self.get_bodyparts(df)
86
+ app = QApplication.instance()
87
+ if app is None:
88
+ app = QApplication([])
89
 
90
+ # 個別の散布図を作成
91
+ for i, bodypart in enumerate(bodyparts):
92
+ try:
93
+ x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
94
+ y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
95
+ except Exception as e:
96
+ continue
97
+ pw = pg.PlotWidget(title=f'トラッキングの座標({bodypart})')
98
+ pw.setLabel('bottom', 'X Coordinate(pixel)')
99
+ pw.setLabel('left', 'Y Coordinate(pixel)')
100
+ pw.setXRange(0, self.x_max)
101
+ pw.setYRange(0, self.y_max)
102
+ pw.invertY(True)
103
+ color = pg.intColor(i, len(bodyparts))
104
+ # 散布図アイテムの追加
105
+ scatter = pg.ScatterPlotItem(x=x, y=y, pen=pg.mkPen(color=color), symbol='o', brush=color)
106
+ pw.addItem(scatter)
107
+ # 始点を黒丸でハイライトし、"Start"テキストを追加
108
+ if len(x) > 0:
109
+ scatter_start = pg.ScatterPlotItem(x=[x[0]], y=[y[0]], pen=pg.mkPen(color='k'), symbol='o', size=10, brush='k')
110
+ pw.addItem(scatter_start)
111
+ text = pg.TextItem("Start", color='k')
112
+ text.setPos(x[0], y[0])
113
+ pw.addItem(text)
114
+ # PNGにエクスポート
115
+ exporter = ImageExporter(pw.plotItem)
116
+ filename = os.path.join(self.output_folder, f"{bodypart}.png")
117
+ exporter.export(filename)
118
+ image_paths.append(filename)
119
 
120
+ # 全付属肢の散布図を作成
121
+ pw_all = pg.PlotWidget(title='トラッキングの座標(全付属肢)')
122
+ pw_all.setLabel('bottom', 'X Coordinate(pixel)')
123
+ pw_all.setLabel('left', 'Y Coordinate(pixel)')
124
+ pw_all.setXRange(0, self.x_max)
125
+ pw_all.setYRange(0, self.y_max)
126
+ pw_all.invertY(True)
127
+ for i, bodypart in enumerate(bodyparts):
128
+ try:
129
+ x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
130
+ y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
131
+ except Exception as e:
132
+ continue
133
+ color = pg.intColor(i, len(bodyparts))
134
+ scatter = pg.ScatterPlotItem(x=x, y=y, pen=pg.mkPen(color=color), symbol='o', brush=color)
135
+ pw_all.addItem(scatter)
136
+ exporter_all = ImageExporter(pw_all.plotItem)
137
+ filename_all = os.path.join(self.output_folder, "all_plot.png")
138
+ exporter_all.export(filename_all)
139
+ image_paths.append(filename_all)
140
+ return image_paths
141
 
142
+ def plot_trajectories(self, df):
143
+ image_paths = []
144
+ bodyparts = self.get_bodyparts(df)
145
+ app = QApplication.instance()
146
+ if app is None:
147
+ app = QApplication([])
148
 
149
+ # 個別の軌跡図を作成
150
+ for i, bodypart in enumerate(bodyparts):
151
+ try:
152
+ x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
153
+ y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
154
+ except Exception as e:
155
+ continue
156
+ pw = pg.PlotWidget(title=f'トラッキングの座標({bodypart})')
157
+ pw.setLabel('bottom', 'Frames')
158
+ pw.setLabel('left', 'Coordinate(pixel)')
159
+ pen_x = pg.mkPen(color=pg.intColor(i, len(bodyparts)), style=QtCore.Qt.PenStyle.DashLine)
160
+ pen_y = pg.mkPen(color=pg.intColor(i, len(bodyparts)))
161
+ pw.plot(x, pen=pen_x, name=f"{bodypart}(x座標)")
162
+ pw.plot(y, pen=pen_y, name=f"{bodypart}(y座標)")
163
+ exporter = ImageExporter(pw.plotItem)
164
+ filename = os.path.join(self.output_folder, f"{bodypart}_trajectories.png")
165
+ exporter.export(filename)
166
+ image_paths.append(filename)
167
 
168
+ # 全付属肢の軌跡図を作成
169
+ pw_all = pg.PlotWidget(title='トラッキングの座標(全付属肢)')
170
+ pw_all.setLabel('bottom', 'Frames')
171
+ pw_all.setLabel('left', 'Coordinate(pixel)')
172
+ for i, bodypart in enumerate(bodyparts):
173
+ try:
174
+ x = np.array(df[f"{bodypart}|x"].to_list(), dtype=float)
175
+ y = np.array(df[f"{bodypart}|y"].to_list(), dtype=float)
176
+ except Exception as e:
177
+ continue
178
+ pen_x = pg.mkPen(color=pg.intColor(i, len(bodyparts)), style=QtCore.Qt.PenStyle.DashLine)
179
+ pen_y = pg.mkPen(color=pg.intColor(i, len(bodyparts)))
180
+ pw_all.plot(x, pen=pen_x, name=f"{bodypart}(x座標)")
181
+ pw_all.plot(y, pen=pen_y, name=f"{bodypart}(y座標)")
182
+ exporter_all = ImageExporter(pw_all.plotItem)
183
+ filename_all = os.path.join(self.output_folder, "all_trajectories.png")
184
+ exporter_all.export(filename_all)
185
+ image_paths.append(filename_all)
186
+ return image_paths
187
 
188
+ def plot_likelihood(self, df_likelihood):
189
+ image_paths = []
190
+ bodyparts = self.get_bodyparts(df_likelihood)
191
+ app = QApplication.instance()
192
+ if app is None:
193
+ app = QApplication([])
194
 
195
+ # 付属肢ごとの尤度グラフを作成
196
+ for i, bodypart in enumerate(bodyparts):
197
+ try:
198
+ likelihood = np.array(df_likelihood[f"{bodypart}|likelihood"].to_list(), dtype=float)
199
+ except Exception as e:
200
+ continue
201
+ pw = pg.PlotWidget(title=f'フレーム別の尤度 ({bodypart})')
202
+ pw.setLabel('bottom', 'Frames')
203
+ pw.setLabel('left', '尤度')
204
+ pw.setYRange(0, 1.0)
205
+ color = pg.intColor(i, len(bodyparts))
206
+ pw.plot(likelihood, pen=pg.mkPen(color=color), name=bodypart)
207
+ exporter = ImageExporter(pw.plotItem)
208
+ filename = os.path.join(self.output_folder, f"{bodypart}_likelihood.png")
209
+ exporter.export(filename)
210
+ image_paths.append(filename)
211
 
212
+ # 全付属肢の尤度グラフを作成
213
+ pw_all = pg.PlotWidget(title='フレーム別の尤度 (全付属肢)')
214
+ pw_all.setLabel('bottom', 'Frames')
215
+ pw_all.setLabel('left', '尤度')
216
+ pw_all.setYRange(0, 1.0)
217
+ for i, bodypart in enumerate(bodyparts):
218
+ try:
219
+ likelihood = np.array(df_likelihood[f"{bodypart}|likelihood"].to_list(), dtype=float)
220
+ except Exception as e:
221
+ continue
222
+ color = pg.intColor(i, len(bodyparts))
223
+ pw_all.plot(likelihood, pen=pg.mkPen(color=color), name=bodypart)
224
+ exporter_all = ImageExporter(pw_all.plotItem)
225
+ filename_all = os.path.join(self.output_folder, "likelihood_plot.png")
226
+ exporter_all.export(filename_all)
227
+ image_paths.append(filename_all)
228
+ return image_paths
229
 
230
+ class GradioInterface:
231
+ def __init__(self):
232
+ self.interface = gr.Interface(
233
+ fn=self.process_and_plot,
234
+ inputs=[
235
+ gr.File(label="CSVファイルをドラッグ&ドロップ"),
236
+ gr.Textbox(
237
+ label="付属肢の名前(カンマ区切り)",
238
+ value="指節1, 指節2, 指節3, 指節4, 指節5, 指節6, 指節7, 指節8, 指節9, 指節10, 指節11, 指節12, 指節13, 指節14, 触角(左), 触角(右), 頭部, 腹尾節"
239
+ ),
240
+ gr.Number(label="X軸の最大値", value=1920),
241
+ gr.Number(label="Y軸の最大値", value=1080),
242
+ gr.CheckboxGroup(
243
+ label="プロットするグラフを選択",
244
+ choices=["散布図", "軌跡図", "尤度グラフ"],
245
+ value=["散布図", "軌跡図", "尤度グラフ"],
246
+ type="value"
247
+ )
248
+ ],
249
+ outputs=[
250
+ gr.Gallery(label="散布図"),
251
+ gr.File(label="ZIPダウンロード")
252
+ ],
253
+ title="DeepLabCutグラフ出力ツール",
254
+ description="CSVファイルからグラフを作成します。"
255
+ )
256
 
257
+ def process_and_plot(self, file, bodypart_names, x_max, y_max, graph_choices):
258
+ processor = DataProcessor(bodypart_names, x_max, y_max)
259
+ df, df_likelihood = processor.process_csv(file.name)
260
 
261
+ all_image_paths = []
262
+ if "散布図" in graph_choices:
263
+ all_image_paths += processor.plot_scatter(df)
264
+ if "軌跡図" in graph_choices:
265
+ all_image_paths += processor.plot_trajectories(df)
266
+ if "尤度グラフ" in graph_choices:
267
+ all_image_paths += processor.plot_likelihood(df_likelihood)
268
 
269
+ shutil.make_archive(processor.output_folder, 'zip', processor.output_folder)
270
+ return all_image_paths, processor.output_folder + '.zip'
271
 
272
+ def launch(self):
273
+ self.interface.launch()
274
 
275
 
276
+ if __name__ == "__main__":
277
+ gradio_app = GradioInterface()
278
+ gradio_app.launch()