ryo2 commited on
Commit
11b7d82
·
verified ·
1 Parent(s): 80d364a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -37
app.py CHANGED
@@ -10,63 +10,126 @@ import numpy as np
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
12
 
13
- def all_likelihood_plot(csv_file_name, tmpdir):
14
- df = pd.read_csv(csv_file_name, header=[1, 2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  df = df.drop(df.columns[[0]], axis=1)
16
- columns = df.columns.droplevel(1)
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # 重複を削除
19
- columns = columns.drop_duplicates()
20
- likelihood = [df[x]["likelihood"] for x in columns]
21
- a = pd.DataFrame(likelihood, index=columns).T
22
 
23
- #平均値を求める
24
- point_average = a.mean()
25
-
26
- # CSVから自動的に取得した付属肢名を使用
27
- parts = columns.tolist()
 
 
 
 
 
 
 
 
 
 
28
 
29
  # カラーマップの設定
30
- cmap = plt.get_cmap('rainbow')
31
- # バイオリン図のプロット
32
- sns.set(style="whitegrid",font="IPAexGothic")
33
- fig, ax = plt.subplots()
34
 
35
- # データをバイオリンプロット描画
36
- sns.violinplot(data=a, palette=[cmap(i)
37
- for i in np.linspace(0, 1, len(columns))], ax=ax,inner=None)
38
 
39
- # 横軸のラベルを重ならないように
40
  plt.xticks(rotation=65)
41
 
42
- ax.set_title('付属肢別尤度')
43
- ax.set_xlabel('付属肢')
44
- ax.set_ylabel('尤度')
 
45
 
46
- #それぞれ要素の平均値をプロット
47
- plt.scatter(x=parts, y=point_average, color='black', marker='x')
48
 
49
- # 最大値を1に
50
  plt.ylim(0, 1)
51
 
52
- #ラベルがはみ出ないように
53
  plt.tight_layout()
54
 
55
- # グラフを表示
56
- plt.savefig(f"likelihood.png", dpi=300)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  def main(csv_file):
 
 
 
 
 
 
 
 
 
 
60
  with tempfile.TemporaryDirectory(dir=".") as tmpdir:
61
- all_likelihood_plot(csv_file, tmpdir)
62
- return f"likelihood.png"
63
 
64
 
 
65
  iface = gr.Interface(
66
- fn=main,
67
- inputs="file",
68
- outputs="image",
69
- title="尤度のグラフを作します。",
70
- description="CSVファイルから自動的に付属肢を抽出してバイオリンプロットを作成します。"
71
  )
72
- iface.launch()
 
 
 
10
  import matplotlib.pyplot as plt
11
  import seaborn as sns
12
 
13
+
14
+ def load_and_process_csv(csv_file_path):
15
+ """
16
+ CSVファイルを読み込み、尤度データと付属肢名を抽出する
17
+
18
+ Args:
19
+ csv_file_path: CSVファイルのパス
20
+
21
+ Returns:
22
+ likelihood_df: 尤度データのDataFrame
23
+ bodyparts: 付属肢の名前リスト
24
+ point_average: 各付属肢の平均尤度
25
+ """
26
+ # DeepLabCutのCSVはマルチヘッダー形式
27
+ df = pd.read_csv(csv_file_path, header=[1, 2])
28
+ # 最初の列(通常はフレーム番号)を削除
29
  df = df.drop(df.columns[[0]], axis=1)
30
+ # 各カラムの第2レベル(x, y, likelihood)を削除し、第1レベル(付属肢名)のみを取得
31
+ bodyparts = df.columns.droplevel(1)
32
+ # 重複を削除(x, y, likelihoodが各付属肢にあるため)
33
+ bodyparts = bodyparts.drop_duplicates()
34
+
35
+ # 各付属肢のlikelihood列を抽出してDataFrameに変換
36
+ likelihood_list = [df[part]["likelihood"] for part in bodyparts]
37
+ likelihood_df = pd.DataFrame(likelihood_list, index=bodyparts).T
38
+
39
+ # 各付属肢の平均尤度を計算
40
+ point_average = likelihood_df.mean()
41
+
42
+ return likelihood_df, bodyparts.tolist(), point_average
43
 
 
 
 
 
44
 
45
+ def create_violin_plot(likelihood_df, bodyparts, point_average, output_path):
46
+ """
47
+ 尤度データからバイオリンプロットを作成して保存する
48
+
49
+ Args:
50
+ likelihood_df: 尤度データのDataFrame
51
+ bodyparts: 付属肢の名前リスト
52
+ point_average: 各付属肢の平均尤度
53
+ output_path: 出力画像の保存パス
54
+ """
55
+ # スタイルとフォントの設定
56
+ sns.set(style="whitegrid", font="IPAexGothic")
57
+
58
+ # グラフの作成
59
+ fig, ax = plt.subplots(figsize=(12, 8))
60
 
61
  # カラーマップの設定
62
+ cmap = plt.get_cmap("rainbow")
63
+ colors = [cmap(i) for i in np.linspace(0, 1, len(bodyparts))]
 
 
64
 
65
+ # バイオリンプロット描画
66
+ sns.violinplot(data=likelihood_df, palette=colors, ax=ax, inner=None)
 
67
 
68
+ # 横軸のラベルを回転して重なりを防止
69
  plt.xticks(rotation=65)
70
 
71
+ # グラフタイトルと軸ラベルの設定
72
+ ax.set_title("付属肢別の尤度")
73
+ ax.set_xlabel("付属肢")
74
+ ax.set_ylabel("尤度")
75
 
76
+ # 各付属肢の平均値をXマーカーでプロット
77
+ plt.scatter(x=range(len(bodyparts)), y=point_average, color="black", marker="x")
78
 
79
+ # 尤度の範囲0-1に設定
80
  plt.ylim(0, 1)
81
 
82
+ # ラベルがはみ出ないようにレイアウトを調整
83
  plt.tight_layout()
84
 
85
+ # グラフをファイルに保存
86
+ plt.savefig(output_path, dpi=300)
87
+ plt.close()
88
+
89
+
90
+ def process_and_plot(csv_file_path):
91
+ """
92
+ CSVファイルを処理してバイオリンプロットを作成する
93
+
94
+ Args:
95
+ csv_file_path: CSVファイルのパス
96
+
97
+ Returns:
98
+ output_path: 作成された画像ファイルのパス
99
+ """
100
+ # CSVデータを読み込み処理
101
+ likelihood_df, bodyparts, point_average = load_and_process_csv(csv_file_path)
102
+
103
+ # バイオリンプロットを作成して保存
104
+ output_path = "likelihood.png"
105
+ create_violin_plot(likelihood_df, bodyparts, point_average, output_path)
106
+
107
+ return output_path
108
 
109
 
110
  def main(csv_file):
111
+ """
112
+ Gradioインターフェース用のメイン関数
113
+
114
+ Args:
115
+ csv_file: Gradioからのファイルオブジェクト
116
+
117
+ Returns:
118
+ 画像ファイルのパス
119
+ """
120
+ # 一時ディレクトリを作成して処理
121
  with tempfile.TemporaryDirectory(dir=".") as tmpdir:
122
+ return process_and_plot(csv_file)
 
123
 
124
 
125
+ # Gradioインターフェースの設定と起動
126
  iface = gr.Interface(
127
+ fn=main,
128
+ inputs="file",
129
+ outputs="image",
130
+ title="DeepLabCut 尤度バイオリンプロット生ツール",
131
+ description="CSVファイルから自動的に付属肢を抽出して尤度のバイオリンプロットを作成します。",
132
  )
133
+
134
+ if __name__ == "__main__":
135
+ iface.launch()