File size: 5,638 Bytes
ec53f2b
 
1964b53
ec53f2b
395b9e7
1faf60c
8c5d0a5
 
ec53f2b
d26f326
8c5d0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d26f326
 
 
 
ec53f2b
d26f326
1964b53
 
ec53f2b
1964b53
 
ec53f2b
1964b53
d26f326
 
 
 
d49804e
 
 
 
1faf60c
d49804e
 
 
d26f326
 
 
8c5d0a5
d26f326
 
 
 
 
 
 
 
 
3b6abf7
d26f326
3b6abf7
d26f326
 
 
 
 
 
 
3b6abf7
d26f326
1964b53
d26f326
 
 
 
 
 
 
 
 
 
3b6abf7
 
1964b53
d26f326
3b6abf7
1964b53
3b6abf7
d26f326
 
 
d49804e
 
 
1faf60c
d49804e
 
 
1964b53
 
d26f326
8c5d0a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc5ceea
8c5d0a5
 
 
 
dc5ceea
8c5d0a5
 
 
 
 
 
 
 
d26f326
efc5d81
8c5d0a5
 
efc5d81
 
8c5d0a5
 
efc5d81
 
d26f326
8c5d0a5
 
 
 
 
 
1964b53
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
import efficientnet.tfkeras as efn 
import random 
import torch
from open_clip import create_model_and_transforms, get_tokenizer

# ==========================================
# 1. Modality Router Setup (BiomedCLIP)
# ==========================================
print("Loading BiomedCLIP Router...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
# create_model_and_transforms yields (model, train-time transform,
# eval-time transform); only the eval-time transform is kept for inference.
clip_model, _, clip_preprocess = create_model_and_transforms(clip_model_name)
# open_clip returns the model in training mode. Switch to eval mode so
# train-only behaviour (e.g. dropout) is disabled and the routing decision
# is deterministic for a given image.
clip_model = clip_model.to(device).eval()
clip_tokenizer = get_tokenizer(clip_model_name)

# Text prompts used as zero-shot routing classes, tokenized once at startup.
# Order matters: process_scan reads index 0 as MRI and index 1 as X-ray.
router_labels = ['an MRI brain scan', 'a chest X-ray']
text_tokens = clip_tokenizer(router_labels).to(device)

# ==========================================
# 2. MRI Model Setup 
# ==========================================
print("Loading MRI model...")
mri_model = tf.keras.models.load_model(filepath="mri.keras")

# Display names for the MRI model's outputs, in output-index order
# (predict_mri pairs predictions[i] with mri_class_names[i]).
mri_class_names = [
    'Glioma',
    'Meningioma',
    'No Tumor',
    'Pituitary Tumor',
]

def predict_mri(image):
    """Classify a brain MRI image with the Keras MRI model.

    Args:
        image: The uploaded image as a NumPy array (``gr.Image``'s default
            output type), or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if ``image`` is ``None``. Otherwise a dict mapping each name
        in ``mri_class_names`` to the model's score for that class, rounded
        to 4 decimal places.
    """
    if image is None:
        return None

    # Preprocess to a single-channel 168x168 image scaled to [0, 1].
    img = Image.fromarray(image).convert('L')
    img = img.resize((168, 168))
    img_array = np.asarray(img, dtype=np.float32) / 255.0

    # Add a batch axis and a channel axis: (168, 168) -> (1, 168, 168, 1).
    img_array = img_array[np.newaxis, :, :, np.newaxis]

    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = mri_model.predict(img_array, verbose=0)[0]

    # Report the model's scores unmodified. An earlier version subtracted a
    # random 0.03-0.07 from every score, which made the same image produce
    # different results and misrepresented the model's actual output.
    return {
        name: round(float(score), 4)
        for name, score in zip(mri_class_names, predictions)
    }

# ==========================================
# 3. X-Ray Model Setup 
# ==========================================
print("Building X-Ray model architecture...")

# Display names for the X-ray model's outputs, in output-index order
# (predict_xray pairs predictions[i] with xray_class_names[i]).
xray_class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
]

def build_xray_model(weights_path="xray.h5"):
    """Build the chest X-ray classifier and load its trained weights.

    Architecture:
    - an EfficientNet-B1 backbone on 128x128 RGB input, with no classification
      head and no pretrained weights;
    - global average pooling;
    - a 1024-unit ReLU layer;
    - one sigmoid output per entry in ``xray_class_names``.

    Keep this layer stack in sync with the saved weights, because
    ``load_weights`` restores them directly onto it.

    Args:
        weights_path: Path of the weights file to load. Defaults to
            ``"xray.h5"``, the path previously hard-coded here.

    Returns:
        The ``tf.keras.Sequential`` model with the weights loaded.
    """
    base_model = efn.EfficientNetB1(
        input_shape=(128, 128, 3),
        weights=None,  # all weights come from weights_path below
        include_top=False,
    )

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024, activation='relu'),
        # Sigmoid per class: each finding is scored independently (multi-label).
        tf.keras.layers.Dense(len(xray_class_names), activation='sigmoid'),
    ])

    model.load_weights(weights_path)
    return model


xray_model = build_xray_model()
print("X-Ray model loaded successfully.")

def predict_xray(image):
    """Score a chest X-ray image for each finding in ``xray_class_names``.

    Args:
        image: The uploaded image as a NumPy array (``gr.Image``'s default
            output type), or ``None`` when nothing was uploaded.

    Returns:
        ``None`` if ``image`` is ``None``. Otherwise a dict mapping each name
        in ``xray_class_names`` to the model's sigmoid score for that finding,
        rounded to 4 decimal places. Scores are independent per finding
        (multi-label), so they need not sum to 1.
    """
    if image is None:
        return None

    img = Image.fromarray(image).convert('RGB')
    img = img.resize((128, 128))

    # Add a batch axis, giving (1, 128, 128, 3), then apply the efficientnet
    # package's own input preprocessing.
    img_array = np.expand_dims(np.array(img), axis=0)
    img_array = efn.preprocess_input(img_array)

    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = xray_model.predict(img_array, verbose=0)[0]

    # Report the model's scores unmodified. An earlier version subtracted a
    # random 0.03-0.07 from every score, which made the same image produce
    # different results and misrepresented the model's actual output.
    return {
        name: round(float(score), 4)
        for name, score in zip(xray_class_names, predictions)
    }

# ==========================================
# 4. Master Routing Function
# ==========================================
def process_scan(image):
    """Detect whether ``image`` is a brain MRI or a chest X-ray, then classify it.

    BiomedCLIP scores the image against the prompts in ``router_labels``
    (index 0 = MRI, index 1 = chest X-ray). The image then goes to the
    specialist model whose prompt matched better; ties go to the X-ray model.

    Args:
        image: The uploaded image as a NumPy array, or ``None``.

    Returns:
        A ``(modality_status, diagnostic_results)`` tuple:
        - ``modality_status`` is a human-readable modality string.
        - ``diagnostic_results`` is a dict of the highest-scoring classes
          (1 for MRI, 2 for X-ray), ordered from highest to lowest score.

        Returns ``("No image provided.", None)`` when ``image`` is ``None``.
    """
    if image is None:
        return "No image provided.", None

    mri_prob, xray_prob = _modality_probabilities(image)

    if mri_prob > xray_prob:
        modality_status = "MRI Brain Scan"
        diagnostic_results = predict_mri(image)
        top_k = 1  # show only the single most likely MRI class
    else:
        modality_status = "Chest X-Ray"
        diagnostic_results = predict_xray(image)
        top_k = 2  # show the two most likely X-ray findings

    # top_k was previously computed but never applied, so every class was shown.
    return modality_status, _top_k_confidences(diagnostic_results, top_k)


def _modality_probabilities(image):
    """Return BiomedCLIP's zero-shot ``(mri_prob, xray_prob)`` for ``image``."""
    img_pil = Image.fromarray(image).convert('RGB')
    img_tensor = clip_preprocess(img_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = clip_model.encode_image(img_tensor)
        text_features = clip_model.encode_text(text_tokens)

        # L2-normalise so the dot product below is a cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # 100 is a fixed logit scale. It sharpens the softmax but cannot
        # change which label wins.
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)[0]

    return text_probs[0].item(), text_probs[1].item()


def _top_k_confidences(confidences, k):
    """Keep the ``k`` highest-scoring entries of ``confidences``, highest first."""
    if confidences is None:
        return None
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])

# ==========================================
# 5. Define the Unified Gradio Interface
# ==========================================
# Single-page UI with one image input. Clicking "Predict" runs process_scan,
# which fills in the detected modality (Textbox) and the class scores (Label).
with gr.Blocks(title="BTech Project") as interface:
    # Layout follows creation order inside these context managers:
    # the left column holds the input and the button, the right column the outputs.
    with gr.Row():
        with gr.Column():
            scan_input = gr.Image(label="Upload XRay or MRI Image")
            analyze_button = gr.Button("Predict", variant="primary")
        
        with gr.Column():
            modality_output = gr.Textbox(label="Image Type", interactive=False)
            diagnostic_output = gr.Label(label="Prediction")
            
    # process_scan returns (modality_status, diagnostic_results). The order of
    # `outputs` maps those two values onto the Textbox and the Label.
    analyze_button.click(
        fn=process_scan, 
        inputs=scan_input, 
        outputs=[modality_output, diagnostic_output]
    )

# Start the web server only when the file is run as a script, not on import.
if __name__ == "__main__":
    interface.launch()