# BTech / app.py
# Last commit by CGAllenger: "Update app.py" (dc5ceea, verified)
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
import efficientnet.tfkeras as efn
import random
import torch
from open_clip import create_model_and_transforms, get_tokenizer
# ==========================================
# 1. Modality Router Setup (BiomedCLIP)
# ==========================================
print("Loading BiomedCLIP Router...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clip_model_name = 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
# The second return value (training-time transform) is unused; only the
# evaluation transform is needed for inference.
clip_model, _, clip_preprocess = create_model_and_transforms(clip_model_name)
# open_clip hands back the model in training mode. Switch to eval mode so
# dropout is disabled and the same image is always routed the same way.
clip_model = clip_model.to(device).eval()
clip_tokenizer = get_tokenizer(clip_model_name)
# Zero-shot routing prompts. Order matters: the routing code reads index 0
# as the MRI probability and index 1 as the X-ray probability.
router_labels = ['an MRI brain scan', 'a chest X-ray']
text_tokens = clip_tokenizer(router_labels).to(device)
# ==========================================
# 2. MRI Model Setup
# ==========================================
print("Loading MRI model...")
# Display names for the MRI classifier's outputs, in output-index order
# (assumed to match the order used at training time — TODO confirm).
mri_class_names = [
    'Glioma',
    'Meningioma',
    'No Tumor',
    'Pituitary Tumor',
]
mri_model = tf.keras.models.load_model("mri.keras")
def predict_mri(image):
    """Classify a brain MRI image into one of ``mri_class_names``.

    The image is converted to single-channel grayscale, resized to 168x168,
    scaled to [0, 1] and passed through ``mri_model`` as a batch of one.

    Args:
        image: Image array from the Gradio ``gr.Image`` input (anything
            ``PIL.Image.fromarray`` accepts), or ``None``.

    Returns:
        ``None`` if no image was given. Otherwise a dict mapping each class
        name to the model's score for that class, rounded to 4 decimals.
    """
    if image is None:
        return None
    img = Image.fromarray(image).convert('L').resize((168, 168))
    # Scale 8-bit grayscale pixels to [0, 1], then add channel and batch
    # axes: (168, 168) -> (1, 168, 168, 1).
    img_array = np.array(img) / 255.0
    img_array = img_array[np.newaxis, ..., np.newaxis]
    predictions = mri_model.predict(img_array, verbose=0)[0]
    # Scores are reported unmodified so results are reproducible and reflect
    # the model's actual output.
    return {
        name: round(float(score), 4)
        for name, score in zip(mri_class_names, predictions)
    }
# ==========================================
# 3. X-Ray Model Setup
# ==========================================
print("Building X-Ray model architecture...")
# Labels for the X-ray head's outputs, one per sigmoid unit, in output-index
# order (presumably the training-time label order — TODO confirm).
xray_class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
]
def build_xray_model():
    """Build the EfficientNet-B1 X-ray classifier and load weights from ``xray.h5``.

    The network is an EfficientNet-B1 backbone (no ImageNet weights, no top)
    on 128x128x3 inputs, followed by global average pooling, a 1024-unit ReLU
    layer and one sigmoid unit per entry in ``xray_class_names``.

    NOTE(review): keep this layer stack unchanged — ``xray.h5`` was
    presumably saved from exactly this architecture.
    """
    backbone = efn.EfficientNetB1(
        include_top=False,
        weights=None,
        input_shape=(128, 128, 3),
    )
    head_layers = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(1024, activation='relu'),
        tf.keras.layers.Dense(len(xray_class_names), activation='sigmoid'),
    ]
    network = tf.keras.Sequential([backbone] + head_layers)
    network.load_weights("xray.h5")
    return network


# Built once at import time and shared by every request.
xray_model = build_xray_model()
print("X-Ray model loaded successfully.")
def predict_xray(image):
    """Score a chest X-ray image against every class in ``xray_class_names``.

    The image is converted to RGB, resized to 128x128, given a batch axis and
    normalized with ``efn.preprocess_input`` before being passed to
    ``xray_model``.

    Args:
        image: Image array from the Gradio ``gr.Image`` input (anything
            ``PIL.Image.fromarray`` accepts), or ``None``.

    Returns:
        ``None`` if no image was given. Otherwise a dict mapping each class
        name to its score, rounded to 4 decimals. The model head uses sigmoid
        activations, so scores are per-class and need not sum to 1.
    """
    if image is None:
        return None
    img = Image.fromarray(image).convert('RGB').resize((128, 128))
    batch = np.expand_dims(np.array(img), axis=0)  # (1, 128, 128, 3)
    batch = efn.preprocess_input(batch)
    predictions = xray_model.predict(batch, verbose=0)[0]
    # Scores are reported unmodified so results are reproducible and reflect
    # the model's actual output.
    return {
        name: round(float(score), 4)
        for name, score in zip(xray_class_names, predictions)
    }
# ==========================================
# 4. Master Routing Function
# ==========================================
def _top_k_confidences(confidences, k):
    """Return the ``k`` highest-scoring entries of ``confidences``, best first."""
    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


def process_scan(image):
    """Detect the scan modality with BiomedCLIP and run the matching classifier.

    Args:
        image: Image array from the Gradio ``gr.Image`` input, or ``None``.

    Returns:
        A ``(modality_status, diagnostic_results)`` tuple.
        ``modality_status`` is the detected modality name, or a message when
        no image was given. ``diagnostic_results`` maps class name to score
        and is limited to the top 1 (MRI) or top 2 (X-ray) classes; it is
        ``None`` when no image was given.
    """
    if image is None:
        return "No image provided.", None

    # Step A: preprocess for CLIP.
    img_pil = Image.fromarray(image).convert('RGB')
    img_tensor = clip_preprocess(img_pil).unsqueeze(0).to(device)

    # Step B: zero-shot similarity between the image and each router prompt.
    with torch.no_grad():
        image_features = clip_model.encode_image(img_tensor)
        text_features = clip_model.encode_text(text_tokens)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)[0]
    mri_prob = text_probs[0].item()
    xray_prob = text_probs[1].item()

    # Step C: route to the modality-specific model (ties go to X-ray).
    if mri_prob > xray_prob:
        modality_status = "MRI Brain Scan"
        diagnostic_results = predict_mri(image)
        top_k = 1  # show only the single most likely MRI class
    else:
        modality_status = "Chest X-Ray"
        diagnostic_results = predict_xray(image)
        top_k = 2  # show the two most likely X-ray findings
    if diagnostic_results:
        diagnostic_results = _top_k_confidences(diagnostic_results, top_k)
    return modality_status, diagnostic_results
# ==========================================
# 5. Define the Unified Gradio Interface
# ==========================================
with gr.Blocks(title="BTech Project") as interface:
    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column():
            scan_input = gr.Image(label="Upload XRay or MRI Image")
            analyze_button = gr.Button("Predict", variant="primary")
        # Right column: detected modality and per-class scores.
        with gr.Column():
            modality_output = gr.Textbox(label="Image Type", interactive=False)
            diagnostic_output = gr.Label(label="Prediction")
    # process_scan returns (modality, scores), matching this output order.
    scan_outputs = [modality_output, diagnostic_output]
    analyze_button.click(
        fn=process_scan,
        inputs=[scan_input],
        outputs=scan_outputs,
    )


if __name__ == "__main__":
    interface.launch()