shy-33 commited on
Commit
47d9236
·
verified ·
1 Parent(s): 383153f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+ import gradio as gr
4
+ from PIL import Image
5
+
6
+ # Device
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ # Load model
10
+ model, _, preprocess = open_clip.create_model_and_transforms(
11
+ 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
12
+ )
13
+
14
+ tokenizer = open_clip.get_tokenizer(
15
+ 'hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224'
16
+ )
17
+
18
+ model = model.to(device).eval()
19
+
20
+ # Labels
21
+ labels = [
22
+ "CT scan of normal lung tissue",
23
+ "CT scan showing lung adenocarcinoma",
24
+ "CT scan showing large cell carcinoma",
25
+ "CT scan showing squamous cell carcinoma"
26
+ ]
27
+
28
+ class_names = [
29
+ "adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib",
30
+ "large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa",
31
+ "normal",
32
+ "squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa"]
33
+
34
+ # Prediction function
35
+ def predict(image):
36
+ image = preprocess(image.convert("RGB")).unsqueeze(0).to(device)
37
+ text = tokenizer(labels).to(device)
38
+
39
+ with torch.no_grad():
40
+ img_feat = model.encode_image(image)
41
+ txt_feat = model.encode_text(text)
42
+
43
+ img_feat /= img_feat.norm(dim=-1, keepdim=True)
44
+ txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
45
+
46
+ similarity = img_feat @ txt_feat.T
47
+ probs = similarity.softmax(dim=-1)[0]
48
+
49
+ pred = torch.argmax(probs).item()
50
+ return class_names[pred],probs.cpu().tolist()
51
+
52
+ def app(image):
53
+ pred, probs = predict(image)
54
+
55
+ return {
56
+ "Prediction": pred,
57
+ "Normal": float(probs[0]),
58
+ "adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib": float(probs[1]),
59
+ "large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa": float(probs[2]),
60
+ "squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa": float(probs[3])
61
+ }
62
+
63
+ demo = gr.Interface(
64
+ fn=app,
65
+ inputs=gr.Image(type="pil"),
66
+ outputs="json",
67
+ title="Multimodal Lung Cancer AI",
68
+ description="Upload a CT scan image to classify lung cancer type using a vision-language AI model."
69
+ )
70
+
71
+ demo.launch()