"""Gradio demo combining a Keras CNN classifier with BiomedCLIP zero-shot matching.

The CNN loaded from ``lung_cancer_model.h5`` scores the image against
``class_names``; BiomedCLIP independently picks whichever prompt in
``clip_texts`` best matches the image. Both results are returned as JSON.
"""

import gradio as gr
import numpy as np
import open_clip
import tensorflow as tf
import torch
from PIL import Image

# CNN MODEL
cnn_model = tf.keras.models.load_model("lung_cancer_model.h5")

# CNN output index -> label.
# NOTE(review): this order must match the class order used at training time
# (it looks alphabetical, as directory-based loaders produce) — confirm.
class_names = [
    "adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib",
    "large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa",
    "normal",
    "squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa",
]

# Size every image is resized to before it reaches the CNN
# (presumably the training input size — confirm).
CNN_INPUT_SIZE = (224, 224)

# BIOMEDCLIP
BIOMEDCLIP_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
device = "cuda" if torch.cuda.is_available() else "cpu"

clip_model, _, preprocess = open_clip.create_model_and_transforms(BIOMEDCLIP_ID)
tokenizer = open_clip.get_tokenizer(BIOMEDCLIP_ID)
clip_model = clip_model.to(device).eval()

# Prompts BiomedCLIP compares the image against; the best match is reported.
clip_texts = [
    "CT scan of normal lung tissue",
    "CT scan showing lung adenocarcinoma",
    "CT scan showing large cell carcinoma",
    "CT scan showing squamous cell carcinoma",
]

# The prompts are constant, so embed and L2-normalise them once at start-up
# rather than re-tokenising and re-encoding them on every request.
with torch.no_grad():
    _clip_text_features = clip_model.encode_text(tokenizer(clip_texts).to(device))
    _clip_text_features = _clip_text_features / _clip_text_features.norm(
        dim=-1, keepdim=True
    )


# CNN PREDICTION
def cnn_predict(image):
    """Classify *image* with the Keras CNN.

    Args:
        image: PIL image of the scan, in any PIL mode.

    Returns:
        ``(label, probs)``: ``label`` is the ``class_names`` entry at the
        highest-scoring output index, and ``probs`` is the model's output row
        for this single image.
    """
    # Force 3 channels, as clip_predict already does. Grayscale ("L") or RGBA
    # uploads would otherwise produce a 2-D or 4-channel array.
    # NOTE(review): assumes the CNN was trained on RGB inputs — confirm.
    img = image.convert("RGB").resize(CNN_INPUT_SIZE)
    img_array = np.array(img) / 255.0
    batch = np.expand_dims(img_array, axis=0)  # add a leading batch axis of 1
    probs = cnn_model.predict(batch, verbose=0)[0]
    pred = class_names[int(np.argmax(probs))]
    return pred, probs


# CLIP PREDICTION
def clip_predict(image):
    """Return the ``clip_texts`` prompt BiomedCLIP rates most similar to *image*.

    Args:
        image: PIL image of the scan; it is converted to RGB before preprocessing.

    Returns:
        The best-matching prompt string from ``clip_texts``.
    """
    image_input = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        img_feat = clip_model.encode_image(image_input)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        # Both sides are unit-normalised, so this is cosine similarity.
        similarity = img_feat @ _clip_text_features.T
    pred = torch.argmax(similarity, dim=-1).item()
    return clip_texts[pred]


# FINAL COMBINED FUNCTION
def predict(image):
    """Run both models on *image* and return their results for the JSON output.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        Dict with the CNN label, the CNN's top output score ("Confidence" —
        a probability only if the model ends in softmax, which is not visible
        here), and the BiomedCLIP best-matching prompt.

    Raises:
        gr.Error: if no image was supplied; Gradio shows this message to the
            user instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload a CT scan image.")
    cnn_pred, cnn_probs = cnn_predict(image)
    clip_result = clip_predict(image)
    return {
        "CNN Prediction": cnn_pred,
        "Confidence": float(np.max(cnn_probs)),
        "BiomedCLIP Insight": clip_result,
    }


# GRADIO APP
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Multimodal Lung Cancer AI",
    description="CNN + BiomedCLIP combined system",
)

if __name__ == "__main__":
    demo.launch()