shy-33's picture
Update app.py
f1748a2 verified
import tensorflow as tf
import numpy as np
import torch
import open_clip
import gradio as gr
from PIL import Image
# CNN MODEL
# Keras classifier, loaded from a file on disk when this module runs.
CNN_MODEL_PATH = "lung_cancer_model.h5"
cnn_model = tf.keras.models.load_model(CNN_MODEL_PATH)

# Output labels, looked up by position in the CNN's output vector (see
# cnn_predict). NOTE(review): this order must match the class order used at
# training time (often alphabetical folder order) — TODO confirm.
class_names = [
    "adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib",
    "large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa",
    "normal",
    "squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa",
]
# BIOMEDCLIP
# One hub id drives both the model/transforms and the matching tokenizer, so
# the two can never drift apart.
BIOMEDCLIP_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Only the third element of the returned tuple is kept as `preprocess`
# (per open_clip's docs, the evaluation-time image transform).
clip_model, _, preprocess = open_clip.create_model_and_transforms(BIOMEDCLIP_ID)
tokenizer = open_clip.get_tokenizer(BIOMEDCLIP_ID)
clip_model = clip_model.to(device).eval()

# Candidate captions for zero-shot matching; clip_predict returns whichever
# caption is most similar to the image.
clip_texts = [
    "CT scan of normal lung tissue",
    "CT scan showing lung adenocarcinoma",
    "CT scan showing large cell carcinoma",
    "CT scan showing squamous cell carcinoma",
]
# CNN PREDICTION
def cnn_predict(image):
    """Classify *image* with the Keras CNN.

    Args:
        image: A PIL image in any colour mode.

    Returns:
        ``(label, probs)``: ``label`` is the ``class_names`` entry at the
        arg-max of the model output, and ``probs`` is the model's output
        vector for this single image.
    """
    # Force 3-channel RGB, the same normalisation clip_predict applies.
    # Otherwise RGBA / greyscale / palette uploads would yield a
    # (224, 224, 4) or (224, 224) array instead of (224, 224, 3).
    # NOTE(review): assumes the CNN was trained on RGB input.
    rgb = image.convert("RGB")
    # NOTE(review): PIL's default resampling filter may differ from the one
    # used when the model was trained — TODO confirm.
    resized = rgb.resize((224, 224))
    # Scale pixels to [0, 1]; float32 is the dtype Keras computes in anyway.
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    batch = np.expand_dims(pixels, axis=0)  # add the batch dimension
    probs = cnn_model.predict(batch, verbose=0)[0]
    pred = class_names[int(np.argmax(probs))]
    return pred, probs
# CLIP PREDICTION
# L2-normalised text embeddings keyed by the prompt tuple. The prompts are
# constant, so they are encoded once instead of on every request.
_clip_text_feature_cache = {}


def _clip_text_features(texts):
    """Return L2-normalised BiomedCLIP embeddings for *texts*, computed once per prompt set."""
    key = tuple(texts)
    cached = _clip_text_feature_cache.get(key)
    if cached is None:
        tokens = tokenizer(list(key)).to(device)
        with torch.no_grad():
            feats = clip_model.encode_text(tokens)
            cached = feats / feats.norm(dim=-1, keepdim=True)
        _clip_text_feature_cache[key] = cached
    return cached


def clip_predict(image):
    """Return the ``clip_texts`` caption that best matches *image*.

    Args:
        image: A PIL image in any colour mode; it is converted to RGB.

    Returns:
        The caption whose embedding has the highest cosine similarity
        (dot product of L2-normalised embeddings) with the image embedding.
    """
    image_input = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        img_feat = clip_model.encode_image(image_input)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = _clip_text_features(clip_texts)
    similarity = img_feat @ txt_feat.T
    pred = torch.argmax(similarity, dim=-1).item()
    return clip_texts[pred]
# FINAL COMBINED FUNCTION
def predict(image):
    """Run both models on one uploaded image and merge their outputs.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A JSON-serialisable dict containing the CNN's top label, the CNN's
        maximum output score, and the best-matching BiomedCLIP caption.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an opaque ``AttributeError`` from ``None``.
    """
    if image is None:
        raise gr.Error("Please upload a CT scan image.")
    cnn_pred, cnn_probs = cnn_predict(image)
    clip_result = clip_predict(image)
    return {
        "CNN Prediction": cnn_pred,
        # Max of the CNN output vector. It is a probability only if the
        # model's last layer is softmax (not visible here — TODO confirm).
        "Confidence": float(np.max(cnn_probs)),
        "BiomedCLIP Insight": clip_result,
    }
# GRADIO APP
# One PIL image upload in, the dict returned by predict() rendered as JSON out.
_upload_widget = gr.Image(type="pil")
demo = gr.Interface(
    fn=predict,
    inputs=_upload_widget,
    outputs="json",
    title="Multimodal Lung Cancer AI",
    description="CNN + BiomedCLIP combined system",
)
# NOTE(review): launch runs unconditionally whenever this module executes
# (no __main__ guard); kept as-is to preserve current behaviour.
demo.launch()