Spaces:
Running
Running
File size: 5,638 Bytes
ec53f2b 1964b53 ec53f2b 395b9e7 1faf60c 8c5d0a5 ec53f2b d26f326 8c5d0a5 d26f326 ec53f2b d26f326 1964b53 ec53f2b 1964b53 ec53f2b 1964b53 d26f326 d49804e 1faf60c d49804e d26f326 8c5d0a5 d26f326 3b6abf7 d26f326 3b6abf7 d26f326 3b6abf7 d26f326 1964b53 d26f326 3b6abf7 1964b53 d26f326 3b6abf7 1964b53 3b6abf7 d26f326 d49804e 1faf60c d49804e 1964b53 d26f326 8c5d0a5 dc5ceea 8c5d0a5 dc5ceea 8c5d0a5 d26f326 efc5d81 8c5d0a5 efc5d81 8c5d0a5 efc5d81 d26f326 8c5d0a5 1964b53 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
import efficientnet.tfkeras as efn
import random
import torch
from open_clip import create_model_and_transforms, get_tokenizer
# ==========================================
# 1. Modality Router Setup (BiomedCLIP)
# ==========================================
# BiomedCLIP is used zero-shot to decide whether an upload is a brain MRI or a
# chest X-ray; process_scan then sends the image to the matching model.
print("Loading BiomedCLIP Router...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
clip_model, _, clip_preprocess = create_model_and_transforms(clip_model_name)
clip_model = clip_model.to(device)
# open_clip hands back the model in training mode; switch to inference mode so
# dropout is disabled and the routing decision is deterministic per image.
clip_model.eval()
clip_tokenizer = get_tokenizer(clip_model_name)
# Candidate captions for zero-shot routing. Order matters: process_scan reads
# probability index 0 as "MRI" and index 1 as "X-ray".
router_labels = ['an MRI brain scan', 'a chest X-ray']
# Tokenized once at start-up; the text embeddings are computed from these
# tokens inside process_scan.
text_tokens = clip_tokenizer(router_labels).to(device)
# ==========================================
# 2. MRI Model Setup
# ==========================================
# Labels for the MRI classifier, listed in the order of the model's outputs
# (predict_mri pairs them with the prediction vector position by position).
mri_class_names = [
    'Glioma',
    'Meningioma',
    'No Tumor',
    'Pituitary Tumor',
]
print("Loading MRI model...")
mri_model = tf.keras.models.load_model("mri.keras")
def predict_mri(image):
    """Classify a brain MRI image against ``mri_class_names``.

    Args:
        image: The uploaded image as a NumPy array (what ``gr.Image`` passes by
            default), or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping each MRI class name to the model's score for that class,
        rounded to four decimals, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    # Grayscale, 168x168, scaled to [0, 1], shaped (1, 168, 168, 1):
    # a batch of one single-channel image.
    img = Image.fromarray(image).convert('L').resize((168, 168))
    img_array = np.array(img) / 255.0
    img_array = img_array[np.newaxis, ..., np.newaxis]
    # verbose=0 keeps Keras from printing a progress bar on every request.
    predictions = mri_model.predict(img_array, verbose=0)[0]
    # Scores are reported unmodified. Perturbing each class by an independent
    # random amount makes results non-reproducible and can reorder close
    # classes, changing which diagnosis is shown as the top prediction.
    return {
        name: round(float(score), 4)
        for name, score in zip(mri_class_names, predictions)
    }
# ==========================================
# 3. X-Ray Model Setup
# ==========================================
print("Building X-Ray model architecture...")
# One label per output unit of the X-ray model, in output order; the model
# emits an independent sigmoid score for each (multi-label).
xray_class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
]
def build_xray_model():
    """Recreate the X-ray classifier architecture and load its trained weights.

    The network is an EfficientNet-B1 backbone (no ImageNet weights, no top),
    followed by global average pooling, a 1024-unit ReLU layer, and one sigmoid
    unit per entry in ``xray_class_names``. Weights are then read from
    ``xray.h5``, so this architecture must match the one they were saved from.

    Returns:
        A ``tf.keras.Sequential`` model with the ``xray.h5`` weights loaded.
    """
    num_classes = len(xray_class_names)
    backbone = efn.EfficientNetB1(
        include_top=False,
        weights=None,
        input_shape=(128, 128, 3),
    )
    # Construction order is kept fixed: Keras derives default layer names
    # from the order in which layers are created.
    head = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='sigmoid'),
    ]
    model = tf.keras.Sequential([backbone, *head])
    model.load_weights("xray.h5")
    return model


xray_model = build_xray_model()
print("X-Ray model loaded successfully.")
def predict_xray(image):
    """Score a chest X-ray against every label in ``xray_class_names``.

    The model ends in one sigmoid unit per label, so the scores are
    independent (multi-label) and need not sum to 1.

    Args:
        image: The uploaded image as a NumPy array (what ``gr.Image`` passes by
            default), or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping each X-ray label to the model's score for it, rounded to
        four decimals, or ``None`` if ``image`` is ``None``.
    """
    if image is None:
        return None
    # RGB, 128x128, batch of one, then the efficientnet package's own input
    # normalization (assumed to match how xray.h5 was trained — TODO confirm).
    img = Image.fromarray(image).convert('RGB').resize((128, 128))
    img_array = np.expand_dims(np.array(img), axis=0)
    img_array = efn.preprocess_input(img_array)
    # verbose=0 keeps Keras from printing a progress bar on every request.
    predictions = xray_model.predict(img_array, verbose=0)[0]
    # Scores are reported unmodified. Subtracting an independent random amount
    # per label makes results non-reproducible and can reorder close labels.
    return {
        name: round(float(score), 4)
        for name, score in zip(xray_class_names, predictions)
    }
# ==========================================
# 4. Master Routing Function
# ==========================================
def process_scan(image):
    """Detect the scan modality and run the matching diagnostic model.

    BiomedCLIP scores the image against ``router_labels`` zero-shot; whichever
    modality is more probable decides whether ``predict_mri`` or
    ``predict_xray`` runs.

    Args:
        image: The uploaded image as a NumPy array, or ``None``.

    Returns:
        A ``(modality_status, diagnostic_results)`` tuple feeding the two
        Gradio outputs. ``diagnostic_results`` holds only the highest-scoring
        classes: 1 for MRI, 2 for X-ray.
    """
    if image is None:
        return "No image provided.", None

    # Step A: Preprocess for CLIP
    img_pil = Image.fromarray(image).convert('RGB')
    img_tensor = clip_preprocess(img_pil).unsqueeze(0).to(device)

    # Step B: Modality probabilities = softmax over the labels of the scaled
    # cosine similarity between the image and each caption embedding.
    with torch.no_grad():
        image_features = clip_model.encode_image(img_tensor)
        text_features = clip_model.encode_text(text_tokens)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)[0]
    mri_prob = text_probs[0].item()
    xray_prob = text_probs[1].item()

    # Step C: Route to the specific model (a tie goes to the X-ray branch).
    if mri_prob > xray_prob:
        modality_status = "MRI Brain Scan"
        diagnostic_results = predict_mri(image)
        top_k = 1  # MRI classes are mutually exclusive: show the single best.
    else:
        modality_status = "Chest X-Ray"
        diagnostic_results = predict_xray(image)
        top_k = 2  # Multi-label X-ray: show the two strongest findings.

    # Trim to the top_k classes so the Label shows only the intended number
    # of predictions instead of every class.
    if diagnostic_results:
        ranked = sorted(
            diagnostic_results.items(), key=lambda item: item[1], reverse=True
        )
        diagnostic_results = dict(ranked[:top_k])
    return modality_status, diagnostic_results
# ==========================================
# 5. Define the Unified Gradio Interface
# ==========================================
# Two side-by-side columns: image upload and Predict button first, then the
# detected modality and the per-class prediction.
with gr.Blocks(title="BTech Project") as interface:
    with gr.Row():
        with gr.Column():
            scan_input = gr.Image(label="Upload XRay or MRI Image")
            analyze_button = gr.Button("Predict", variant="primary")
        with gr.Column():
            modality_output = gr.Textbox(label="Image Type", interactive=False)
            diagnostic_output = gr.Label(label="Prediction")
    # process_scan returns (modality text, class scores) in this output order.
    analyze_button.click(
        process_scan,
        inputs=scan_input,
        outputs=[modality_output, diagnostic_output],
    )

if __name__ == "__main__":
    interface.launch()