File size: 4,798 Bytes
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import logging
import os
import sys

import numpy as np
import pandas as pd
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging. TF_CPP_MIN_LOG_LEVEL must be set before
# TensorFlow is first imported (the models below import it transitively).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    # absl is optional; TF's own logger is silenced below regardless.
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from model import RawImageModel, PrecomputedModel
from dicom_utils import read_dicom_image
from PIL import Image

def resolve_file_column(df):
    """Return the name of the column holding image paths, or None if absent.

    Checks the candidate column names in priority order ('file' first,
    then 'dicom_file') to adapt to the different CSV layouts seen in
    the Kaggle exports.
    """
    for candidate in ('file', 'dicom_file'):
        if candidate in df.columns:
            return candidate
    return None


def main():
    """Run zero-shot pneumothorax inference over a Kaggle DICOM dataset.

    Reads a labels CSV, loads the image/text embedding models, scores every
    DICOM file against hard-coded positive/negative pneumothorax prompts,
    and writes one row per input file (with score or error) to a CSV.
    Results are checkpointed every 10 rows so partial runs are recoverable.
    """
    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle", help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv", help="Output predictions file")
    args = parser.parse_args()

    # Create output directory
    os.makedirs(os.path.dirname(args.output), exist_ok=True)

    # Load dataset
    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    # Determine which column holds the image paths.
    file_col = resolve_file_column(df)
    if file_col is None:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    # Initialize Models
    try:
        # We need PrecomputedModel for text embeddings (labels)
        precomputed_model = PrecomputedModel()

        # We need RawImageModel for the images
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        # logger.fatal is a deprecated alias; critical is the documented name.
        logger.critical(f"Failed to initialize models: {e}")
        return

    # Get text embeddings for diagnosis
    diagnosis = 'PNEUMOTHORAX'
    try:
        # Hardcoded prompts matching main.py
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []

    # Iterate and predict
    print(f"Running inference for {diagnosis} on {len(df)} images...")

    # Scratch PNG reused for every image (RawImageModel reads from disk).
    temp_path = "temp_inference.png"

    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]
            # Construct full path
            full_path = os.path.join(args.data_dir, file_path) if not os.path.isabs(file_path) else file_path

            # Check if file exists
            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found'
                })
                continue

            # Ground-truth label column varies across CSV layouts.
            true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))

            try:
                # 1. Read DICOM
                image_array = read_dicom_image(full_path)

                # 2. Save as temp PNG (Required by RawImageModel/TF pipeline currently)
                Image.fromarray(image_array).save(temp_path)

                # 3. Compute Image Embedding
                img_emb = raw_model.compute_embeddings(temp_path)

                # 4. Compute Zero-Shot Score
                score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)

                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': float(score)
                })

            except Exception as e:
                # Record the failure but keep processing the remaining rows.
                logger.warning(f"Failed to process {file_path}: {e}")
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': None,
                    'error': str(e)
                })

            # Incremental Save every 10 items
            if len(predictions) % 10 == 0:
                pd.DataFrame(predictions).to_csv(args.output, index=False)

        # Final Save
        results_df = pd.DataFrame(predictions)
        results_df.to_csv(args.output, index=False)
        logger.info(f"Predictions saved to {args.output}")
    finally:
        # Cleanup the scratch file even if the loop raised.
        if os.path.exists(temp_path):
            os.remove(temp_path)

# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()