File size: 4,099 Bytes
b412062
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import argparse
import logging
import sys

# Suppress TensorFlow and system warnings.
# These environment variables are set before the project imports further down
# (presumably where TensorFlow gets loaded -- confirm in model.py) so that they
# are already in place when that code is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Silence every Python warning for the lifetime of the process.
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

# Configure logging first, so the root handler/format is installed before the
# libraries imported below get a chance to emit log records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress absl logging from TensorFlow.
# absl is optional here: when it is not installed there is nothing to silence,
# so the ImportError is deliberately ignored.
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    pass

# Suppress TensorFlow Python logging (only ERROR and above pass through).
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Project-local imports are kept last so all of the suppression above is
# active by the time they run.
from model import PrecomputedModel, RawImageModel
from evaluate import evaluate_predictions

# Supported diagnoses. Each key doubles as the --diagnosis CLI choice and as
# the label column name; each value is a (positive prompt, negative prompt)
# pair of text queries used for zero-shot scoring.
DIAGNOSIS_PROMPTS = dict(
    AIRSPACE_OPACITY=(
        'Airspace Opacity',
        'no evidence of airspace disease',
    ),
    PNEUMOTHORAX=(
        'small pneumothorax',
        'no pneumothorax',
    ),
    EFFUSION=(
        'large pleural effusion',
        'no pleural effusion',
    ),
    PULMONARY_EDEMA=(
        'moderate pulmonary edema',
        'no pulmonary edema',
    ),
)

def main():
    """Command-line entry point for zero-shot chest X-ray classification.

    Parses ``--diagnosis`` (required, one of the ``DIAGNOSIS_PROMPTS`` keys),
    ``--data-dir`` (default ``"data"``) and the optional ``--raw-image`` path,
    then runs in one of two modes:

    * Raw image mode (``--raw-image`` given): embeds the single image with
      ``RawImageModel``, scores it against the positive/negative text
      embeddings, logs the score and prints it to stdout.
    * Dataset mode (default): scores every image whose label for the chosen
      diagnosis is 0 or 1 using the precomputed embeddings and hands the
      scores and matching labels to ``evaluate_predictions``.

    Exits the process with status 1 when the raw image cannot be processed or
    when no scores could be computed for the dataset.
    """
    parser = argparse.ArgumentParser(description="Zero-Shot Chest X-Ray Classification")
    parser.add_argument("--diagnosis", type=str, choices=DIAGNOSIS_PROMPTS.keys(), required=True, help="Diagnosis to evaluate")
    parser.add_argument("--data-dir", type=str, default="data", help="Path to data directory")
    parser.add_argument("--raw-image", type=str, help="Path to a raw image file for inference (optional)")
    args = parser.parse_args()

    # Resolve the positive/negative text prompts for the requested diagnosis.
    pos_txt, neg_txt = DIAGNOSIS_PROMPTS[args.diagnosis]
    logger.info(f"Diagnosis: {args.diagnosis}")
    logger.info(f"Positive query: '{pos_txt}'")
    logger.info(f"Negative query: '{neg_txt}'")

    # The precomputed model is needed in both modes: it supplies the text
    # embeddings, and in dataset mode also the labels and image embeddings.
    precomputed_model = PrecomputedModel(data_dir=args.data_dir)
    pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)

    if args.raw_image:
        # Raw Image Inference Mode
        logger.info(f"Running inference on raw image: {args.raw_image}")
        raw_model = RawImageModel()
        try:
            # NOTE(review): image_emb shape is likely (1, 32, 128) or (32, 128),
            # and PrecomputedModel.zero_shot is expected to accept either --
            # confirm against model.py.
            image_emb = raw_model.compute_embeddings(args.raw_image)
            score = PrecomputedModel.zero_shot(image_emb, pos_emb, neg_emb)
            logger.info(f"Zero-shot score for {args.raw_image}: {score:.4f}")

            # A single image cannot yield a meaningful AUC, so only the raw
            # score is reported.
            print(f"Score for {args.diagnosis}: {score}")
        except Exception as e:
            # Top-level boundary: keep the traceback in the log, then fail the
            # process with a non-zero status.
            logger.exception(f"Failed to process raw image: {e}")
            sys.exit(1)
        return

    # Precomputed Embeddings Evaluation Mode (Full Dataset)
    logger.info("Running evaluation on full precomputed dataset...")

    # Keep only rows whose label for the target diagnosis is 0 or 1.
    labels_df = precomputed_model.labels
    target_df = labels_df[labels_df[args.diagnosis].isin([0, 1])][['image_id', args.diagnosis]].copy()

    image_ids = target_df['image_id'].tolist()
    true_labels = target_df[args.diagnosis].tolist()

    # Compute scores; valid_ids are the ids for which a score was produced.
    valid_ids, scores = precomputed_model.compute_scores(image_ids, pos_emb, neg_emb)

    # Explicit length check: truth-testing would raise if ``scores`` were a
    # multi-element NumPy array, whereas len() works for lists and arrays.
    if len(scores) == 0:
        logger.error("No valid scores computed. Check embedding match.")
        sys.exit(1)

    # Drop labels of images that had no embedding. The set is built once so
    # each membership test is O(1) instead of a scan over valid_ids.
    valid_id_set = set(valid_ids)
    final_labels = [
        label
        for img_id, label in zip(image_ids, true_labels)
        if img_id in valid_id_set
    ]

    # Evaluate
    evaluate_predictions(scores, final_labels, args.diagnosis)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()