ryo2 commited on
Commit
dc16bdc
·
verified ·
1 Parent(s): 30397cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -34
app.py CHANGED
@@ -8,45 +8,34 @@ import japanize_matplotlib
8
 
9
 
10
  class DataProcessor:
11
- def __init__(self, bodypart_names, x_max, y_max):
12
- self.bodypart_names = bodypart_names.split(',')
13
  self.x_max = x_max
14
  self.y_max = y_max
15
  self.output_folder = 'output_plots'
 
16
 
17
  def process_csv(self, file_path):
18
  df = pd.read_csv(file_path, header=[1, 2])
 
 
 
 
 
 
 
19
  df_likelihood = self.extract_likelihood(df)
20
  df = self.remove_first_column_and_likelihood(df)
21
- df = self.rename_bodyparts(df)
22
- return df, df_likelihood
23
 
24
  def remove_first_column_and_likelihood(self, df):
25
  df = df.drop(df.columns[0], axis=1)
26
  df = df[df.columns.drop(list(df.filter(regex='likelihood')))]
27
  return df
28
 
29
- def rename_bodyparts(self, df):
30
- current_names = df.columns.get_level_values(0).unique()
31
- if len(self.bodypart_names) != len(current_names):
32
- raise ValueError(
33
- "The length of bodypart_names must be equal to the number of bodyparts.")
34
- mapping = dict(zip(current_names, self.bodypart_names))
35
- new_columns = [(mapping[col[0]], col[1]) if col[0]
36
- in mapping else col for col in df.columns]
37
- df.columns = pd.MultiIndex.from_tuples(new_columns)
38
- return df
39
-
40
  def extract_likelihood(self, df):
41
  # likelihood列のみを抽出する
42
  df = df[df.columns[df.columns.get_level_values(1) == 'likelihood']]
43
  df.drop(df.columns[0], axis=1)
44
- current_names = df.columns.get_level_values(0).unique()
45
- mapping = dict(zip(current_names, self.bodypart_names))
46
- new_columns = [(mapping[col[0]], col[1]) if col[0]
47
- in mapping else col for col in df.columns]
48
- df.columns = pd.MultiIndex.from_tuples(new_columns)
49
-
50
  return df
51
 
52
  def plot_scatter(self, df):
@@ -172,16 +161,12 @@ class DataProcessor:
172
  image_paths.append(f'{self.output_folder}/likelihood_plot.png')
173
  return image_paths
174
 
175
- # 以下のGradioInterfaceクラスとメイン実行部分は変更なし
176
-
177
  class GradioInterface:
178
  def __init__(self):
179
  self.interface = gr.Interface(
180
  fn=self.process_and_plot,
181
  inputs=[
182
  gr.File(label="CSVファイルをドラッグ&ドロップ"),
183
- gr.Textbox(label="付属肢の名前(カンマ区切り)",
184
- value="指節1, 指節2, 指節3, 指節4, 指節5, 指節6, 指節7, 指節8, 指節9,指節10, 指節11, 指節12, 指節13, 指節14, 触角(左), 触角(右), 頭部, 腹尾節"),
185
  gr.Number(label="X軸の最大値", value=1920),
186
  gr.Number(label="Y軸の最大値", value=1080),
187
  gr.CheckboxGroup(
@@ -192,16 +177,17 @@ class GradioInterface:
192
  )
193
  ],
194
  outputs=[
195
- gr.Gallery(label="散布図"),
196
- gr.File(label="ZIPダウンロード")
 
197
  ],
198
  title="DeepLabCutグラフ出力ツール",
199
- description="CSVファイルからグラフを作成します。"
200
  )
201
 
202
- def process_and_plot(self, file, bodypart_names, x_max, y_max, graph_choices):
203
- processor = DataProcessor(bodypart_names, x_max, y_max)
204
- df, df_likelihood = processor.process_csv(file.name)
205
 
206
  all_image_paths = []
207
  if "散布図" in graph_choices:
@@ -211,9 +197,11 @@ class GradioInterface:
211
  if "尤度グラフ" in graph_choices:
212
  all_image_paths += processor.plot_likelihood(df_likelihood)
213
 
214
- shutil.make_archive(processor.output_folder,
215
- 'zip', processor.output_folder)
216
- return all_image_paths, processor.output_folder + '.zip'
 
 
217
 
218
  def launch(self):
219
  self.interface.launch()
 
8
 
9
 
10
  class DataProcessor:
11
+ def __init__(self, x_max, y_max):
 
12
  self.x_max = x_max
13
  self.y_max = y_max
14
  self.output_folder = 'output_plots'
15
+ self.bodypart_names = None # 初期化時にはNoneに設定
16
 
17
  def process_csv(self, file_path):
18
  df = pd.read_csv(file_path, header=[1, 2])
19
+
20
+ # CSVから自動的に付属肢名を抽出
21
+ self.bodypart_names = df.columns.get_level_values(0).unique().tolist()
22
+ # 最初の列(通常はscorerなど)を除外
23
+ if len(self.bodypart_names) > 0:
24
+ self.bodypart_names = self.bodypart_names[1:]
25
+
26
  df_likelihood = self.extract_likelihood(df)
27
  df = self.remove_first_column_and_likelihood(df)
28
+ return df, df_likelihood, self.bodypart_names # 抽出した付属肢名も返す
 
29
 
30
  def remove_first_column_and_likelihood(self, df):
31
  df = df.drop(df.columns[0], axis=1)
32
  df = df[df.columns.drop(list(df.filter(regex='likelihood')))]
33
  return df
34
 
 
 
 
 
 
 
 
 
 
 
 
35
  def extract_likelihood(self, df):
36
  # likelihood列のみを抽出する
37
  df = df[df.columns[df.columns.get_level_values(1) == 'likelihood']]
38
  df.drop(df.columns[0], axis=1)
 
 
 
 
 
 
39
  return df
40
 
41
  def plot_scatter(self, df):
 
161
  image_paths.append(f'{self.output_folder}/likelihood_plot.png')
162
  return image_paths
163
 
 
 
164
  class GradioInterface:
165
  def __init__(self):
166
  self.interface = gr.Interface(
167
  fn=self.process_and_plot,
168
  inputs=[
169
  gr.File(label="CSVファイルをドラッグ&ドロップ"),
 
 
170
  gr.Number(label="X軸の最大値", value=1920),
171
  gr.Number(label="Y軸の最大値", value=1080),
172
  gr.CheckboxGroup(
 
177
  )
178
  ],
179
  outputs=[
180
+ gr.Gallery(label="グラフ"),
181
+ gr.File(label="ZIPダウンロード"),
182
+ gr.Textbox(label="検出された付属肢") # 検出された付属肢を表示するための出力を追加
183
  ],
184
  title="DeepLabCutグラフ出力ツール",
185
+ description="CSVファイルからグラフを作成します。付属肢はCSVファイルから自動的に抽出されます。"
186
  )
187
 
188
+ def process_and_plot(self, file, x_max, y_max, graph_choices):
189
+ processor = DataProcessor(x_max, y_max)
190
+ df, df_likelihood, bodypart_names = processor.process_csv(file.name)
191
 
192
  all_image_paths = []
193
  if "散布図" in graph_choices:
 
197
  if "尤度グラフ" in graph_choices:
198
  all_image_paths += processor.plot_likelihood(df_likelihood)
199
 
200
+ # 付属肢の名前を表示用に結合
201
+ bodyparts_text = ", ".join(bodypart_names)
202
+
203
+ shutil.make_archive(processor.output_folder, 'zip', processor.output_folder)
204
+ return all_image_paths, processor.output_folder + '.zip', bodyparts_text
205
 
206
  def launch(self):
207
  self.interface.launch()