# NOTE(review): web-scrape residue removed here (hosting-page chrome, file
# size, commit hash b412062, and a line-number gutter that were not source).
import os
import sys
import logging
import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm

# Configure logging for this script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Suppress TensorFlow logging. The env var must be set BEFORE any module that
# pulls in TensorFlow is imported (presumably `model` below does — verify).
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
try:
    import absl.logging
    absl.logging.set_verbosity(absl.logging.ERROR)
except ImportError:
    # absl is optional; TF-level suppression below still applies.
    pass
logging.getLogger('tensorflow').setLevel(logging.ERROR)

# Project-local imports, deliberately kept after the TF log-level setup.
from model import RawImageModel, PrecomputedModel
from dicom_utils import read_dicom_image
from PIL import Image
def main():
    """Run zero-shot pneumothorax inference over a Kaggle DICOM dataset.

    Loads a labels CSV, embeds each DICOM image with ``RawImageModel``,
    scores it against positive/negative text prompts via
    ``PrecomputedModel.zero_shot``, and writes a predictions CSV
    (saved incrementally every 10 rows, plus a final save).
    """
    parser = argparse.ArgumentParser(description="Evaluate on Kaggle DICOM Dataset")
    parser.add_argument("--csv", default="data/kaggle/labels.csv", help="Path to labels CSV")
    parser.add_argument("--data-dir", default="data/kaggle",
                        help="Root directory for images if relative paths in CSV")
    parser.add_argument("--output", default="results/kaggle_predictions.csv",
                        help="Output predictions file")
    args = parser.parse_args()

    # Ensure the output directory exists before any incremental save.
    # Guard against a bare filename: os.makedirs('') raises.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Load the dataset; bail out early with a logged error on failure.
    try:
        df = pd.read_csv(args.csv)
        logger.info(f"Loaded {len(df)} records from {args.csv}")
    except Exception as e:
        logger.error(f"Failed to load CSV: {e}")
        return

    # Resolve the column holding image paths ('file' preferred, then
    # 'dicom_file'); abort if neither is present.
    file_col = next((c for c in ('file', 'dicom_file') if c in df.columns), None)
    if file_col is None:
        logger.error(f"Missing file column in CSV. Found: {df.columns}")
        return

    # Initialize models: PrecomputedModel supplies text (label) embeddings,
    # RawImageModel computes image embeddings.
    try:
        precomputed_model = PrecomputedModel()
        raw_model = RawImageModel()
        logger.info("Models loaded successfully.")
    except Exception as e:
        logger.critical(f"Failed to initialize models: {e}")
        return

    # Text embeddings for the diagnosis; prompts are hardcoded to match main.py.
    diagnosis = 'PNEUMOTHORAX'
    try:
        pos_txt = 'small pneumothorax'
        neg_txt = 'no pneumothorax'
        pos_emb, neg_emb = precomputed_model.get_diagnosis_embeddings(pos_txt, neg_txt)
    except Exception as e:
        logger.critical(f"Failed to get text embeddings: {e}")
        return

    predictions = []
    print(f"Running inference for {diagnosis} on {len(df)} images...")
    temp_path = "temp_inference.png"
    try:
        for _, row in tqdm(df.iterrows(), total=len(df)):
            file_path = row[file_col]
            # Resolve relative paths against --data-dir.
            full_path = file_path if os.path.isabs(file_path) else os.path.join(args.data_dir, file_path)

            if not os.path.exists(full_path):
                logger.warning(f"File not found: {full_path}")
                predictions.append({
                    'file': file_path,
                    'true_label': None,
                    'pneumothorax_score': None,
                    'error': 'File not found'
                })
                continue

            true_label = row.get('label', row.get('PNEUMOTHORAX', 'Unknown'))
            try:
                # 1. Read DICOM pixel data.
                image_array = read_dicom_image(full_path)
                # 2. Save as a temp PNG (the RawImageModel/TF pipeline
                #    currently requires a file on disk).
                Image.fromarray(image_array).save(temp_path)
                # 3. Compute the image embedding.
                img_emb = raw_model.compute_embeddings(temp_path)
                # 4. Zero-shot score against the text embeddings.
                score = PrecomputedModel.zero_shot(img_emb, pos_emb, neg_emb)
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': float(score)
                })
            except Exception as e:
                # Record the failure in the output row rather than aborting
                # the whole run.
                logger.debug(f"Failed to process {file_path}: {e}")
                predictions.append({
                    'file': file_path,
                    'true_label': true_label,
                    'pneumothorax_score': None,
                    'error': str(e)
                })

            # Incremental save every 10 items so a crash loses little work.
            if len(predictions) % 10 == 0:
                pd.DataFrame(predictions).to_csv(args.output, index=False)

        # Final save.
        results_df = pd.DataFrame(predictions)
        results_df.to_csv(args.output, index=False)
        logger.info(f"Predictions saved to {args.output}")
    finally:
        # Remove the scratch PNG even if inference crashed mid-run
        # (original only cleaned up on the happy path, and via a
        # hardcoded filename rather than temp_path).
        if os.path.exists(temp_path):
            os.remove(temp_path)
# Script entry point: run the evaluation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()