|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from transformers import AutoImageProcessor, SiglipForImageClassification |
|
|
from PIL import Image |
|
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize |
|
|
import numpy as np |
|
|
import time |
|
|
import argparse |
|
|
import random |
|
|
import xml.etree.ElementTree as ET |
|
|
import re |
|
|
|
|
|
|
|
|
# Checkpoint directory handed to ``from_pretrained`` for both the image
# processor and the classification model.
ckpt_dir = "siglip2-ecg-multilabel/checkpoint-750"

# Fallback locations used by the CLI when --image_dir / --xml_dir are omitted.
DEFAULT_IMAGE_DIR = "/home/nagashimadaichi/dev/vectorize-ecg/data/images"
DEFAULT_XML_DIR = "/home/nagashimadaichi/dev/vectorize-ecg/data/xml"

# Label vocabulary, read relative to the current working directory.
# predict_image indexes the model's probabilities by position in this list.
with open('all_labels.json', mode='r', encoding='utf-8') as labels_file:
    all_labels = json.load(labels_file)
|
|
|
|
|
|
|
|
# Per-label thresholds come from an optional JSON file. When the file is
# absent, every label falls back to a fixed default threshold of 0.1.
threshold_file = 'threshold_analysis_results.json'

try:
    with open(threshold_file, mode='r', encoding='utf-8') as f:
        threshold_data = json.load(f)
except FileNotFoundError:
    print(f"警告: {threshold_file} が見つかりません。デフォルト閾値を使用します。")
    label_thresholds, default_threshold = {}, 0.1
else:
    # Map each label to its best threshold; labels missing from the file use
    # the file's average threshold instead.
    label_thresholds = {
        entry['label']: entry['best_threshold']
        for entry in threshold_data['results']
    }
    default_threshold = threshold_data['average_threshold']
    print(f"閾値設定を {threshold_file} から読み込みました")
    print(f"平均閾値: {default_threshold:.4f}")
|
|
|
|
|
|
|
|
# Lead names; a label that equals one of these gets the lowest multiplier.
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Substrings whose presence anywhere in a label selects the 0.7 multiplier.
diagnosis_keywords = ['洞調律', '心房細動', '心筋梗塞', '**', '心室性', '心房性', 'ブロック', '軸偏位']


def get_group_threshold_multiplier(label):
    """Return the factor by which a label's base threshold is scaled.

    The checks run in order and the first match wins:

    * 0.2 -- the label is exactly one of ``leads``.
    * 0.3 -- the label is exactly one of four headline labels.
    * 0.7 -- the label contains any entry of ``diagnosis_keywords``.
    * 0.8 -- the label is digits separated by one or two hyphens, e.g.
      ``1-1`` or ``1-1-1`` (presumably Minnesota codes -- confirm).
    * 0.9 -- anything else.

    Args:
        label: Label text to classify.

    Returns:
        The multiplier as a float.
    """
    if label in leads:
        multiplier = 0.2
    elif label in ('洞調律', '** normal ECG **', '** abnormal ECG **', '心房細動'):
        # Must be tested before the keyword scan: these labels also contain
        # diagnosis keywords and would otherwise receive 0.7.
        multiplier = 0.3
    elif any(keyword in label for keyword in diagnosis_keywords):
        multiplier = 0.7
    elif re.match(r'^\d+-\d+(-\d+)?$', label):
        multiplier = 0.8
    else:
        multiplier = 0.9
    return multiplier
|
|
|
|
|
|
|
|
print("モデルをロード中...")

# Processor and model are both restored from the same checkpoint directory.
processor = AutoImageProcessor.from_pretrained(ckpt_dir)
model = SiglipForImageClassification.from_pretrained(ckpt_dir)
model.eval()  # evaluation mode: this script only runs forward passes

# Preprocessing reuses the processor's own normalisation statistics and its
# configured height, which is applied to both sides (square resize).
image_mean = processor.image_mean
image_std = processor.image_std
size = processor.size['height']

transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std),
])
|
|
|
|
|
def get_random_image(image_dir=DEFAULT_IMAGE_DIR):
    """Choose one ``.png`` file at random from the tree under ``image_dir``.

    Args:
        image_dir: Directory that is walked recursively for candidates.

    Returns:
        Path of the chosen file (``image_dir``-rooted, as built by
        ``os.path.join``).

    Raises:
        FileNotFoundError: No file ending in ``.png`` exists under
            ``image_dir``.
    """
    candidates = [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(image_dir)
        for filename in filenames
        if filename.endswith('.png')
    ]
    if len(candidates) == 0:
        raise FileNotFoundError(f"画像ファイルが {image_dir} に見つかりません")

    chosen = random.choice(candidates)
    print(f"ランダムに選択された画像: {chosen}")
    return chosen
|
|
|
|
|
def parse_ecg_xml(xml_path):
    """Read diagnosis texts and Minnesota codes out of an ECG XML file.

    Elements are looked up in the ``urn:hl7-org:v3`` namespace.

    Args:
        xml_path: Path of the XML file to parse.

    Returns:
        A ``(diagnoses, codes)`` tuple of string lists. ``diagnoses`` holds the
        stripped, non-empty ``text`` of every ``interpretationResult`` in the
        document. ``codes`` holds the stripped, non-empty ``value`` texts found
        inside each ``justifiedDecisionGroup`` whose ``interpretationCode`` has
        ``displayName`` equal to ``ミネソタコード``. Any failure is reported on
        stdout and yields ``([], [])``.
    """
    namespaces = {'v3': 'urn:hl7-org:v3'}
    try:
        document = ET.parse(xml_path).getroot()

        diagnoses = []
        for result in document.findall('.//v3:interpretationResult', namespaces):
            text_node = result.find('v3:text', namespaces)
            if text_node is None or not text_node.text:
                continue
            cleaned = text_node.text.strip()
            if cleaned:
                diagnoses.append(cleaned)

        codes = []
        for group in document.findall('.//v3:justifiedDecisionGroup', namespaces):
            code_node = group.find('v3:interpretationCode', namespaces)
            # Only groups explicitly tagged as Minnesota codes contribute.
            if code_node is None or code_node.get('displayName') != 'ミネソタコード':
                continue
            for value_node in group.findall('.//v3:interpretationResult/v3:value', namespaces):
                cleaned = (value_node.text or '').strip()
                if cleaned:
                    codes.append(cleaned)

        return diagnoses, codes
    except Exception as e:
        # Best effort: a broken or missing file must not abort the caller.
        print(f"XML解析エラー: {e}")
        return [], []
|
|
|
|
|
def get_gt_labels(image_path, xml_root_dir):
    """Collect the ground-truth labels that belong to an image.

    The image file name is split on underscores: the first two pieces name the
    XML file (``<first>_<second>.xml``) and the trailing piece supplies the
    lead. The XML file is searched for recursively under ``xml_root_dir`` and
    the first hit is parsed with ``parse_ecg_xml``.

    Args:
        image_path: Path of the image; only its base name is inspected.
        xml_root_dir: Directory tree searched for the matching XML file.

    Returns:
        A list of label strings: the lead (when non-empty) followed by the
        diagnoses and codes from the XML. If no XML file is found the list is
        just ``[lead]``, or a single not-found message when the lead is empty.
        On any exception the error is printed and a single error-message
        entry is returned.
    """
    try:
        filename = os.path.basename(image_path)
        pieces = filename.split('_')
        base, date = pieces[0], pieces[1]

        # Prefer an explicit lead suffix; otherwise use whatever sits between
        # the last underscore and the first dot after it.
        suffix = re.search(r'_(aVR|aVL|aVF|[IV][1-6])\.png$', filename)
        lead = suffix.group(1) if suffix else pieces[-1].split('.')[0]

        xml_file = f"{base}_{date}.xml"
        for dirpath, _, filenames in os.walk(xml_root_dir):
            if xml_file not in filenames:
                continue

            diagnoses, codes = parse_ecg_xml(os.path.join(dirpath, xml_file))
            labels = [lead] if lead else []
            labels += diagnoses
            labels += codes
            print(f"正解ラベル数: {len(labels)}個 (誘導1個 + 診断{len(diagnoses)}個 + コード{len(codes)}個)")
            return labels

        # No XML anywhere in the tree: the lead alone is the ground truth.
        return [lead] if lead else ["XMLファイルが見つかりませんでした"]
    except Exception as e:
        print(f"ラベル取得エラー: {e}")
        return ["ラベル取得エラー"]
|
|
|
|
|
def predict_image(image_path, min_confidence=0.01, max_results=25, use_threshold_file=True):
    """Classify one image and return the labels that clear their threshold.

    Args:
        image_path: Path of the image file to classify.
        min_confidence: Floor applied to every threshold. When
            ``use_threshold_file`` is false it is the only threshold used.
        max_results: Maximum number of predictions to return.
        use_threshold_file: When true, each label's threshold is its entry in
            ``label_thresholds`` (or ``default_threshold`` if absent) scaled
            by ``get_group_threshold_multiplier``, but never below
            ``min_confidence``.

    Returns:
        A ``(predictions, inference_time)`` tuple. ``predictions`` is a list
        of ``(label, probability, threshold)`` tuples sorted by descending
        probability and cut to ``max_results`` entries. ``inference_time`` is
        the elapsed time in seconds, covering image loading, preprocessing,
        the forward pass and thresholding.
    """
    print(f"画像を予測中: {image_path}")
    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()

    # Scope the file handle explicitly; the converted copy outlives it.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    pixel = transforms(image).unsqueeze(0)

    with torch.no_grad():
        outputs = model(pixel_values=pixel)
        # torch.sigmoid is numerically stable, whereas 1 / (1 + exp(-x))
        # overflows in exp() for strongly negative logits.
        probs = torch.sigmoid(outputs.logits)[0].cpu().numpy()

    predictions = []
    for i, label in enumerate(all_labels):
        if use_threshold_file:
            base_threshold = label_thresholds.get(label, default_threshold)
            adjusted_threshold = base_threshold * get_group_threshold_multiplier(label)
            threshold = max(adjusted_threshold, min_confidence)
        else:
            threshold = min_confidence

        if probs[i] >= threshold:
            predictions.append((label, probs[i], threshold))

    # Most confident first, then keep only the requested number of rows.
    predictions.sort(key=lambda x: x[1], reverse=True)
    predictions = predictions[:max_results]

    inference_time = time.perf_counter() - start_time
    return predictions, inference_time
|
|
|
|
|
def main():
    """Command-line entry point: predict one ECG image and score the result.

    Parses the arguments, resolves the image (the given path, or a random one
    from ``--image_dir``), fetches the ground-truth labels, runs the
    prediction, and prints the ground truth, a table of predictions, any
    ground-truth labels that were not predicted, and precision / recall / F1.
    Returns early after printing an error when no image can be resolved.
    """
    parser = argparse.ArgumentParser(description='ECG画像の予測（シンプル版）')
    parser.add_argument('--image_path', help='予測する画像ファイルのパス（指定しない場合はランダムに選択）')
    parser.add_argument('--image_dir', default=DEFAULT_IMAGE_DIR, help=f'画像ディレクトリ（デフォルト: {DEFAULT_IMAGE_DIR}）')
    parser.add_argument('--xml_dir', default=DEFAULT_XML_DIR, help=f'XMLファイルディレクトリ（デフォルト: {DEFAULT_XML_DIR}）')
    parser.add_argument('--min_confidence', type=float, default=0.01, help='最小信頼度（デフォルト: 0.01）')
    parser.add_argument('--max_results', type=int, default=25, help='表示する最大予測数（デフォルト: 25）')
    # Two flags share one destination; per-label thresholds are on by default.
    parser.add_argument('--use_threshold_file', action='store_true', help='threshold_analysis_results.jsonの閾値を使用する')
    parser.add_argument('--no_threshold_file', dest='use_threshold_file', action='store_false', help='単一の閾値を使用する')
    parser.set_defaults(use_threshold_file=True)
    args = parser.parse_args()

    # Resolve which image to classify.
    image_path = args.image_path
    if image_path:
        if not os.path.exists(image_path):
            print(f"エラー: 画像ファイル '{image_path}' が見つかりません")
            return
    else:
        try:
            image_path = get_random_image(args.image_dir)
        except FileNotFoundError as e:
            print(f"エラー: {e}")
            return

    gt_labels = get_gt_labels(image_path, args.xml_dir)

    predictions, inference_time = predict_image(
        image_path,
        min_confidence=args.min_confidence,
        max_results=args.max_results,
        use_threshold_file=args.use_threshold_file,
    )

    # Set-based metrics over predicted versus ground-truth labels.
    pred_labels_set = {label for label, _prob, _threshold in predictions}
    gt_labels_set = set(gt_labels)
    hit_count = len(gt_labels_set.intersection(pred_labels_set))

    if pred_labels_set:
        precision = hit_count / len(pred_labels_set)
    else:
        precision = 0
    if gt_labels_set:
        recall = hit_count / len(gt_labels_set)
    else:
        recall = 0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    print("\n予測結果:")
    print(f"画像: {os.path.basename(image_path)}")
    print(f"推論時間: {inference_time:.4f}秒\n")

    print("正解ラベル一覧:")
    for i, label in enumerate(gt_labels, 1):
        print(f"{i}. {label}")

    rule = "-" * 75
    print("\n" + rule)
    print(f"{'予測ラベル':<45} {'信頼度':>10} {'閾値':>10} {'正解':>5}")
    print(rule)

    # Confidence bands, checked from highest cutoff to lowest; anything below
    # the last cutoff gets the red marker.
    confidence_bands = ((0.9, '🟢'), (0.7, '🟡'), (0.4, '🟠'))
    for label, prob, threshold in predictions:
        confidence_mark = next(
            (mark for cutoff, mark in confidence_bands if prob >= cutoff),
            '🔴',
        )
        match_mark = '✅' if label in gt_labels_set else '❌'
        print(f"{label:<45} {prob:.4f} {confidence_mark} {threshold:.4f} {match_mark}")

    missed_labels = gt_labels_set - pred_labels_set
    if missed_labels:
        print("\n検出できなかった正解ラベル:")
        for label in missed_labels:
            print(f"- {label}")

    print("\nメトリクス:")
    print(f"適合率 (Precision): {precision:.4f}")
    print(f"再現率 (Recall): {recall:.4f}")
    print(f"F1スコア: {f1:.4f}")


if __name__ == "__main__":
    main()