|
|
import os |
|
|
import json |
|
|
import torch |
|
|
from transformers import AutoImageProcessor, SiglipForImageClassification |
|
|
from PIL import Image |
|
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize |
|
|
import numpy as np |
|
|
import xml.etree.ElementTree as ET |
|
|
import time |
|
|
import random |
|
|
import argparse |
|
|
import re |
|
|
|
|
|
|
|
|
|
|
|
# Directory of the fine-tuned SigLIP checkpoint to load the classifier from.
ckpt_dir = "siglip2-ecg-multilabel/checkpoint-750"

# Full label vocabulary; list order defines the model's output index order.
with open('all_labels.json', 'r', encoding='utf-8') as f:
    all_labels = json.load(f)

# Index -> label lookup (equivalent to {i: lab for i, lab in enumerate(...)}).
id2label = dict(enumerate(all_labels))
|
|
|
|
|
|
|
|
# Load the processor/model pair from the fine-tuned checkpoint and switch
# the model to inference mode (disables dropout etc.).
processor = AutoImageProcessor.from_pretrained(ckpt_dir)
model = SiglipForImageClassification.from_pretrained(ckpt_dir)
model.eval()

# Rebuild the training-time preprocessing pipeline from the processor config:
# square resize, tensor conversion, then per-channel normalization.
image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size['height']
transforms = Compose(
    [
        Resize((size, size)),
        ToTensor(),
        Normalize(mean=image_mean, std=image_std),
    ]
)
|
|
|
|
|
|
|
|
# Per-label decision thresholds produced by an offline threshold sweep.
threshold_file = 'threshold_analysis_results.json'
try:
    with open(threshold_file, 'r', encoding='utf-8') as f:
        threshold_data = json.load(f)

    # Map each label to its tuned threshold; labels missing from the file
    # fall back to the sweep's average threshold.
    label_thresholds = {item['label']: item['best_threshold'] for item in threshold_data['results']}
    default_threshold = threshold_data['average_threshold']
    print(f"閾値設定を {threshold_file} から読み込みました")
except (FileNotFoundError, json.JSONDecodeError, KeyError):
    # Missing, unparsable, or schema-incomplete file: fall back to a uniform
    # default instead of crashing (previously only FileNotFoundError was caught).
    print(f"警告: {threshold_file} が見つかりません。デフォルト閾値を使用します。")
    label_thresholds = {}
    default_threshold = 0.5
|
|
|
|
|
|
|
|
# The 12 standard ECG lead names; lead labels get a lowered threshold multiplier.
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Substrings identifying diagnosis-text labels (sinus rhythm, AF, MI, blocks, ...).
# NOTE(review): the '**' entry matches any label containing two literal asterisks —
# confirm this is intentional and not a placeholder left in by mistake.
diagnosis_keywords = ['洞調律', '心房細動', '心筋梗塞', '**', '心室性', '心房性', 'ブロック', '軸偏位', '左室肥大', '右室肥大']
|
|
|
|
|
# Pattern for Minnesota-code-style labels such as "1-2" or "1-2-3";
# compiled once at module load instead of on every call.
_MINNESOTA_CODE_RE = re.compile(r'^\d+-\d+(-\d+)?$')


def get_group_threshold_multiplier(label):
    """Return the threshold multiplier for a label based on its group.

    Groups are checked in priority order:
        * ECG lead names                 -> 0.8 (easier to accept)
        * diagnosis-text keywords        -> 3.0 (harder to accept)
        * Minnesota codes ("n-n[-n]")    -> 2.5
        * everything else                -> 2.0
    """
    if label in leads:
        return 0.8

    # any() short-circuits on the first matching keyword, same as the
    # original explicit loop.
    if any(keyword in label for keyword in diagnosis_keywords):
        return 3.0

    if _MINNESOTA_CODE_RE.match(label):
        return 2.5

    return 2.0
|
|
|
|
|
|
|
|
def predict_image(image_path, threshold=0.5, max_predictions=15, min_confidence=0.05):
    """Predict labels for one image, keeping those above their per-label threshold.

    Returns a tuple of (list of (label, probability) sorted by descending
    probability, inference time in seconds).

    NOTE(review): ``threshold`` is unused — the per-label thresholds computed
    below take precedence; kept in the signature for caller compatibility.
    """
    t0 = time.time()

    img = Image.open(image_path).convert("RGB")
    batch = transforms(img).unsqueeze(0)

    with torch.no_grad():
        out = model(pixel_values=batch)

    logits = out.logits.cpu().numpy()[0]
    # Multi-label setup: independent sigmoid per class.
    probs = 1 / (1 + np.exp(-logits))

    candidates = []
    for idx, label in enumerate(all_labels):
        # Tuned per-label threshold (or the global default), scaled by the
        # label's group multiplier.
        cutoff = label_thresholds.get(label, default_threshold)
        cutoff = cutoff * get_group_threshold_multiplier(label)

        p = probs[idx]
        if p >= cutoff and p >= min_confidence:
            candidates.append((label, p, cutoff))

    # Highest-confidence predictions first, capped at max_predictions.
    top = sorted(candidates, key=lambda item: item[1], reverse=True)[:max_predictions]

    elapsed = time.time() - t0
    return [(lab, p) for lab, p, _ in top], elapsed
|
|
|
|
|
|
|
|
|
|
|
# Default data locations.
# NOTE(review): main() hard-codes these same paths as its argument fallbacks
# instead of reusing these constants — they appear unused in this file;
# confirm nothing else imports them before removing.
image_root = "/home/nagashimadaichi/dev/vectorize-ecg/data/images"

xml_root = "/home/nagashimadaichi/dev/vectorize-ecg/data/xml"
|
|
|
|
|
|
|
|
def parse_ecg_xml(xml_path):
    """Extract diagnosis texts and Minnesota codes from an HL7v3 ECG XML file.

    Args:
        xml_path: Path to the XML file.

    Returns:
        tuple[list[str], list[str]]: (diagnoses, minnesota_codes). Both lists
        are empty when the file is missing or cannot be parsed.
    """
    try:
        tree = ET.parse(xml_path)
    except (ET.ParseError, OSError):
        # Narrowed from a bare ``except:`` — only missing/unreadable files and
        # malformed XML are treated as "no annotations"; real bugs now surface.
        return [], []

    root = tree.getroot()
    ns = {'v3': 'urn:hl7-org:v3'}

    # Free-text diagnoses: every non-empty <text> under an interpretationResult.
    diagnoses = []
    for interp in root.findall('.//v3:interpretationResult', ns):
        text = interp.find('v3:text', ns)
        if text is not None and text.text and text.text.strip():
            diagnoses.append(text.text.strip())

    # Minnesota codes: <value> elements inside decision groups whose
    # interpretationCode is labelled "ミネソタコード".
    codes = []
    for group in root.findall('.//v3:justifiedDecisionGroup', ns):
        code_elem = group.find('v3:interpretationCode', ns)
        if code_elem is not None and code_elem.get('displayName') == 'ミネソタコード':
            for val in group.findall('.//v3:interpretationResult/v3:value', ns):
                if val.text and val.text.strip():
                    codes.append(val.text.strip())

    return diagnoses, codes
|
|
|
|
|
def get_gt_labels(image_path, xml_root_dir):
    """Derive ground-truth labels for an ECG image from its filename.

    Filename convention: ``<base>_<date>_..._<lead>.png``. The lead name is
    always returned; when ``<base>_<date>.xml`` is found anywhere under
    ``xml_root_dir``, the XML's diagnoses and Minnesota codes are appended.

    Returns:
        list[str]: [lead] or [lead, *diagnoses, *codes].
    """
    fn = os.path.basename(image_path)
    parts = fn.split('_')
    # Lead name is the last underscore-separated token with the extension stripped.
    lead = parts[-1].split('.')[0]
    if len(parts) < 2:
        # Filename doesn't follow the "<base>_<date>_..." convention, so no
        # XML lookup is possible (previously this raised IndexError).
        return [lead]
    base, date = parts[0], parts[1]
    xml_file = f"{base}_{date}.xml"
    # Search the whole XML tree for the matching annotation file.
    for root_dir, _, files in os.walk(xml_root_dir):
        if xml_file in files:
            xml_path = os.path.join(root_dir, xml_file)
            diagnoses, codes = parse_ecg_xml(xml_path)
            return [lead] + diagnoses + codes
    return [lead]
|
|
|
|
|
def main():
    """Evaluate the classifier on a random sample of 5 PNGs and print per-image
    and aggregate metrics (exact match, precision, recall, F1, timing).

    NOTE(review): reads the module-level ``args`` namespace created in the
    ``__main__`` block — this function must be called after argument parsing.
    """
    # CLI overrides fall back to hard-coded data locations.
    test_dir = args.test_dir if args.test_dir else "/home/nagashimadaichi/dev/vectorize-ecg/data/images"
    xml_dir = args.xml_dir if args.xml_dir else "/home/nagashimadaichi/dev/vectorize-ecg/data/xml"

    # Recursively collect every PNG under the test directory.
    all_images = []
    for root, _, files in os.walk(test_dir):
        for f in files:
            if f.endswith('.png'):
                all_images.append(os.path.join(root, f))

    if not all_images:
        print(f"Error: 画像ファイルが {test_dir} に見つかりません")
        return

    # Random sample of 5 images for a quick spot-check evaluation.
    random.shuffle(all_images)
    test_images = all_images[:5]

    # Running totals for the aggregate metrics printed at the end.
    correct_exact = 0
    total_f1 = 0
    total_precision = 0
    total_recall = 0
    total_time = 0
    total_images = len(test_images)

    for img_path in test_images:
        # Ground truth: lead name from the filename plus XML diagnoses/codes.
        gt_labels = get_gt_labels(img_path, xml_dir)

        labels, elapsed = predict_image(img_path, max_predictions=15, min_confidence=0.05)
        total_time += elapsed

        # Exact match: predicted label set equals the ground-truth set.
        if set(gt_labels) == set([label for label, _ in labels]):
            correct_exact += 1

        gt_set = set(gt_labels)
        pred_set = set([label for label, _ in labels])
        intersection = gt_set.intersection(pred_set)
        union = gt_set.union(pred_set)

        # Set-based precision/recall/F1 per image; an empty set counts as 1.0.
        precision = len(intersection) / len(pred_set) if pred_set else 1.0
        recall = len(intersection) / len(gt_set) if gt_set else 1.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        total_f1 += f1
        total_precision += precision
        total_recall += recall

        print("\n" + "=" * 50)
        print(f"画像: {os.path.basename(img_path)}")
        print(f"- 推論時間: {elapsed:.4f}秒")

        print("\n予測結果(確率 / 使用閾値):")
        for label, prob in labels:
            # Recompute the adjusted threshold purely for display.
            base_threshold = label_thresholds.get(label, default_threshold)
            adjusted_threshold = base_threshold * get_group_threshold_multiplier(label)
            print(f"- {label:<30} : {prob:.4f} / {adjusted_threshold:.4f}")

        # Ground-truth labels, marked hit (正解) or miss (未検出).
        print("\n正解ラベル:")
        for label in gt_labels:
            if label in [l for l, _ in labels]:
                print(f"- {label:<30} [正解]")
            else:
                print(f"- {label:<30} [未検出]")

        # False positives: predictions absent from the ground truth.
        print("\n不正解の予測:")
        for label, prob in labels:
            if label not in gt_set:
                print(f"- {label:<30} : {prob:.4f} [過検出]")

        # Jaccard-style agreement plus precision/recall/F1 for this image.
        print(f"\n一致率: {len(intersection)}/{len(union)} = {len(intersection)/len(union):.4f}")
        print(f"適合率: {precision:.4f}, 再現率: {recall:.4f}, F1スコア: {f1:.4f}")

    # Aggregate summary over the sampled images.
    print("\n" + "=" * 50)
    print(f"全体の結果({total_images}枚):")
    print(f"完全一致率: {correct_exact}/{total_images} = {correct_exact/total_images:.4f}")
    print(f"平均適合率: {total_precision/total_images:.4f}")
    print(f"平均再現率: {total_recall/total_images:.4f}")
    print(f"平均F1スコア: {total_f1/total_images:.4f}")
    print(f"平均推論時間: {total_time/total_images:.4f}秒")
    print(f"推論速度: {total_images/total_time:.2f}枚/秒")
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point. ``args`` is deliberately left at module scope because
    # main() reads it as a global rather than taking it as a parameter.
    parser = argparse.ArgumentParser(description='ECG画像の多ラベル予測')
    parser.add_argument('--test_dir', help='テスト画像のあるディレクトリ')
    parser.add_argument('--xml_dir', help='XMLファイルのあるディレクトリ')

    args = parser.parse_args()
    main()