# siglip2-ecg-multilabel / predict_multilabel.py
# (uploaded via huggingface_hub by longisland3; revision aeac162)
import os
import json
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import numpy as np
import xml.etree.ElementTree as ET
import time
import random
import argparse
import re
# --- Configuration ---
# Path to the fine-tuned SigLIP checkpoint directory.
ckpt_dir = "siglip2-ecg-multilabel/checkpoint-750"
# Full label vocabulary: one entry per classifier output, in logit order.
with open('all_labels.json', 'r', encoding='utf-8') as f:
    all_labels = json.load(f)
# Index -> label mapping.  NOTE(review): not referenced elsewhere in this
# file; presumably kept for interactive/debug use -- confirm before removing.
id2label = {i: lab for i, lab in enumerate(all_labels)}
# Load the image processor and classification model from the checkpoint.
processor = AutoImageProcessor.from_pretrained(ckpt_dir)
model = SiglipForImageClassification.from_pretrained(ckpt_dir)
model.eval()
# Image preprocessing: resize to the processor's square input size and
# normalize with the processor's mean/std so inference matches training.
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size['height']
transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std)
])
# Load per-label decision thresholds produced by a prior threshold analysis.
threshold_file = 'threshold_analysis_results.json'
try:
    with open(threshold_file, 'r', encoding='utf-8') as f:
        threshold_data = json.load(f)
    # Map each label to its tuned best threshold.
    label_thresholds = {}
    for item in threshold_data['results']:
        label_thresholds[item['label']] = item['best_threshold']
    # Fallback threshold for labels absent from the analysis file.
    default_threshold = threshold_data['average_threshold']
    print(f"閾値設定を {threshold_file} から読み込みました")
except FileNotFoundError:
    print(f"警告: {threshold_file} が見つかりません。デフォルト閾値を使用します。")
    label_thresholds = {}
    default_threshold = 0.5
# Group-specific threshold settings.
# The 12 standard ECG lead names.
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
# Substrings that mark diagnosis-type labels (sinus rhythm, atrial
# fibrillation, myocardial infarction, ...).
# NOTE(review): the '**' entry looks odd -- confirm it is intentional.
diagnosis_keywords = ['洞調律', '心房細動', '心筋梗塞', '**', '心室性', '心房性', 'ブロック', '軸偏位', '左室肥大', '右室肥大']
def get_group_threshold_multiplier(label):
    """Return a threshold multiplier for *label* based on its label group.

    Lead names get a reduced multiplier (lead detection is reliable);
    diagnosis-related labels and Minnesota codes get a raised one to
    suppress over-detection; everything else is moderately raised.
    """
    # ECG lead names are detected reliably -> lower the effective threshold.
    if label in leads:
        return 0.8
    # Labels containing any diagnosis keyword -> raise the threshold sharply.
    if any(keyword in label for keyword in diagnosis_keywords):
        return 3.0
    # Minnesota codes look like "1-2" or "1-2-3" (digits joined by hyphens).
    if re.match(r'^\d+-\d+(-\d+)?$', label):
        return 2.5
    # All remaining labels.
    return 2.0
# --- Inference function ---
def predict_image(image_path, threshold=0.5, max_predictions=15, min_confidence=0.05):
    """Run the model on one image and return the labels whose probability
    clears their per-label, group-adjusted threshold.

    Returns a pair ``(predictions, elapsed_seconds)`` where *predictions*
    is ``[(label, probability), ...]`` sorted by descending probability and
    capped at *max_predictions*.

    NOTE(review): the *threshold* parameter is currently unused -- the
    per-label thresholds from the analysis file (scaled by the group
    multiplier) are applied instead; kept for interface compatibility.
    """
    t0 = time.time()
    # Preprocess: load as RGB, apply resize/normalize, add a batch dimension.
    rgb = Image.open(image_path).convert("RGB")
    batch = transforms(rgb).unsqueeze(0)
    with torch.no_grad():
        logits = model(pixel_values=batch).logits.cpu().numpy()[0]
    # Sigmoid -> independent per-label probabilities (multi-label setting).
    probs = 1 / (1 + np.exp(-probs_input)) if False else 1 / (1 + np.exp(-logits))
    candidates = []
    for idx, name in enumerate(all_labels):
        # Per-label tuned threshold (default when the label has none),
        # scaled by the label-group multiplier.
        cutoff = label_thresholds.get(name, default_threshold) * get_group_threshold_multiplier(name)
        p = probs[idx]
        if p >= cutoff and p >= min_confidence:
            candidates.append((name, p, cutoff))
    # Highest-confidence predictions first, capped at max_predictions.
    candidates.sort(key=lambda item: item[1], reverse=True)
    kept = candidates[:max_predictions]
    elapsed = time.time() - t0
    return [(name, p) for name, p, _ in kept], elapsed
# --- Quick test on sample images ---
# Example data directories.
# NOTE(review): ``image_root`` and ``xml_root`` are never read below --
# main() uses the CLI args (with its own hard-coded defaults) instead;
# these look like leftovers to be confirmed/removed.
image_root = "/home/nagashimadaichi/dev/vectorize-ecg/data/images"
xml_root = "/home/nagashimadaichi/dev/vectorize-ecg/data/xml"
# --- Ground-truth extraction helpers ---
def parse_ecg_xml(xml_path):
    """Parse an HL7-v3 ECG report XML file.

    Returns a tuple ``(diagnoses, minnesota_codes)``:
      diagnoses       -- stripped, non-empty text of every
                         ``<interpretationResult><text>`` element
      minnesota_codes -- ``<value>`` texts from each
                         ``<justifiedDecisionGroup>`` whose
                         ``<interpretationCode>`` has
                         ``displayName="ミネソタコード"``

    On a parse or I/O failure the function degrades to ``([], [])`` so
    that one corrupt/missing file does not abort a batch run.
    """
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {'v3': 'urn:hl7-org:v3'}
        diagnoses = []
        for interp in root.findall('.//v3:interpretationResult', ns):
            text = interp.find('v3:text', ns)
            if text is not None and text.text and text.text.strip():
                diagnoses.append(text.text.strip())
        codes = []
        for group in root.findall('.//v3:justifiedDecisionGroup', ns):
            code_elem = group.find('v3:interpretationCode', ns)
            if code_elem is not None and code_elem.get('displayName') == 'ミネソタコード':
                for val in group.findall('.//v3:interpretationResult/v3:value', ns):
                    if val.text and val.text.strip():
                        codes.append(val.text.strip())
        return diagnoses, codes
    except (ET.ParseError, OSError):
        # Narrowed from a bare ``except:`` -- only swallow the expected
        # file-access/parse failures; programming errors now propagate.
        return [], []
def get_gt_labels(image_path, xml_root_dir):
    """Build the ground-truth label list for one image.

    Image filenames look like ``<base>_<date>_..._<lead>.png``; the matching
    report is ``<base>_<date>.xml`` somewhere under *xml_root_dir*.  The
    result always starts with the lead name, followed by the diagnoses and
    Minnesota codes parsed from the XML (just the lead if no XML is found).
    """
    stem_parts = os.path.basename(image_path).split('_')
    base, date = stem_parts[0], stem_parts[1]
    # Lead name is the last underscore-separated token, minus the extension.
    lead = stem_parts[-1].split('.')[0]
    wanted = f"{base}_{date}.xml"
    # Search the XML tree recursively for the matching report file.
    for dirpath, _dirs, filenames in os.walk(xml_root_dir):
        if wanted not in filenames:
            continue
        diagnoses, codes = parse_ecg_xml(os.path.join(dirpath, wanted))
        return [lead, *diagnoses, *codes]
    return [lead]
def main():
    """Evaluate the model on up to 5 randomly chosen test images and print
    per-image predictions plus aggregate precision/recall/F1 statistics.

    Reads the module-level ``args`` (set in the ``__main__`` guard) for the
    image and XML directories, falling back to hard-coded dev paths.
    """
    # Pick the test-image and XML directories (CLI args or dev defaults).
    test_dir = args.test_dir if args.test_dir else "/home/nagashimadaichi/dev/vectorize-ecg/data/images"
    xml_dir = args.xml_dir if args.xml_dir else "/home/nagashimadaichi/dev/vectorize-ecg/data/xml"
    # Collect every PNG under the test directory (recursive).
    all_images = []
    for root, _, files in os.walk(test_dir):
        for f in files:
            if f.endswith('.png'):
                all_images.append(os.path.join(root, f))
    if not all_images:
        print(f"Error: 画像ファイルが {test_dir} に見つかりません")
        return
    # Random sample of at most 5 images.
    random.shuffle(all_images)
    test_images = all_images[:5]
    # Accumulators for aggregate metrics.
    correct_exact = 0
    total_f1 = 0
    total_precision = 0
    total_recall = 0
    total_time = 0
    total_images = len(test_images)
    for img_path in test_images:
        # Ground-truth labels from the matching XML report.
        gt_labels = get_gt_labels(img_path, xml_dir)
        # Predict (cap at 15 labels, minimum probability 0.05).
        labels, elapsed = predict_image(img_path, max_predictions=15, min_confidence=0.05)
        total_time += elapsed
        # Exact set match (order-insensitive).
        if set(gt_labels) == set([label for label, _ in labels]):
            correct_exact += 1
        # Set overlap: order-insensitive agreement between GT and predictions.
        gt_set = set(gt_labels)
        pred_set = set([label for label, _ in labels])
        intersection = gt_set.intersection(pred_set)
        union = gt_set.union(pred_set)
        # Precision / recall / F1 (empty sets score 1.0 by convention here).
        precision = len(intersection) / len(pred_set) if pred_set else 1.0
        recall = len(intersection) / len(gt_set) if gt_set else 1.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        total_f1 += f1
        total_precision += precision
        total_recall += recall
        print("\n" + "=" * 50)
        print(f"画像: {os.path.basename(img_path)}")
        print(f"- 推論時間: {elapsed:.4f}秒")
        print("\n予測結果(確率 / 使用閾値):")
        for label, prob in labels:
            # Recompute the effective threshold only for display purposes.
            base_threshold = label_thresholds.get(label, default_threshold)
            adjusted_threshold = base_threshold * get_group_threshold_multiplier(label)
            print(f"- {label:<30} : {prob:.4f} / {adjusted_threshold:.4f}")
        print("\n正解ラベル:")
        for label in gt_labels:
            if label in [l for l, _ in labels]:
                print(f"- {label:<30} [正解]")
            else:
                print(f"- {label:<30} [未検出]")
        print("\n不正解の予測:")
        for label, prob in labels:
            if label not in gt_set:
                print(f"- {label:<30} : {prob:.4f} [過検出]")
        print(f"\n一致率: {len(intersection)}/{len(union)} = {len(intersection)/len(union):.4f}")
        print(f"適合率: {precision:.4f}, 再現率: {recall:.4f}, F1スコア: {f1:.4f}")
    # Aggregate results over all sampled images.
    print("\n" + "=" * 50)
    print(f"全体の結果({total_images}枚):")
    print(f"完全一致率: {correct_exact}/{total_images} = {correct_exact/total_images:.4f}")
    print(f"平均適合率: {total_precision/total_images:.4f}")
    print(f"平均再現率: {total_recall/total_images:.4f}")
    print(f"平均F1スコア: {total_f1/total_images:.4f}")
    print(f"平均推論時間: {total_time/total_images:.4f}秒")
    print(f"推論速度: {total_images/total_time:.2f}枚/秒")
if __name__ == "__main__":
    # CLI: optional overrides for the test-image and XML directories.
    parser = argparse.ArgumentParser(description='ECG画像の多ラベル予測')
    parser.add_argument('--test_dir', help='テスト画像のあるディレクトリ')
    parser.add_argument('--xml_dir', help='XMLファイルのあるディレクトリ')
    # NOTE: ``args`` is deliberately module-level -- main() reads it as a
    # global, so this script cannot be driven via an imported main() call.
    args = parser.parse_args()
    main()