GSMK committed on
Commit
d3b8219
·
verified ·
1 Parent(s): 90df017

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -17
app.py CHANGED
@@ -11,6 +11,7 @@ from transformers import (
11
 
12
  import open_clip
13
 
 
14
  st.set_page_config(page_title="Multi-Domain Zero Shot AI", layout="wide")
15
 
16
  st.title("Multi-Domain Zero Shot Image Classification")
@@ -26,15 +27,17 @@ device = "cpu"
26
  @st.cache_resource
27
  def load_models():
28
 
29
- # -------- BIOMED CLIP (via Transformers CLIP) --------
30
- biomed_model = CLIPModel.from_pretrained(
31
- "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
32
- ).to(device).eval()
33
 
34
- biomed_processor = CLIPProcessor.from_pretrained(
35
- "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
36
  )
37
 
 
 
38
  # -------- REMOTE CLIP --------
39
  remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
40
  "ViT-B-32",
@@ -65,7 +68,8 @@ def load_models():
65
 
66
  return (
67
  biomed_model,
68
- biomed_processor,
 
69
  remote_model,
70
  remote_preprocess,
71
  remote_tokenizer,
@@ -78,7 +82,8 @@ def load_models():
78
 
79
  (
80
  biomed_model,
81
- biomed_processor,
 
82
  remote_model,
83
  remote_preprocess,
84
  remote_tokenizer,
@@ -174,22 +179,19 @@ if uploaded_file:
174
 
175
 
176
  # --------------------------------------------------
177
- # MEDICAL / SKIN (BIOMEDCLIP)
178
  # --------------------------------------------------
179
 
180
  if dataset_key in ["medical", "skin_disease"]:
181
 
182
- inputs = biomed_processor(
183
- text=text_queries,
184
- images=image,
185
- return_tensors="pt",
186
- padding=True
187
- ).to(device)
188
 
189
  with torch.no_grad():
190
- outputs = biomed_model(**inputs)
 
191
 
192
- similarity = outputs.logits_per_image.softmax(dim=1)
193
 
194
 
195
  # --------------------------------------------------
 
11
 
12
  import open_clip
13
 
14
+
15
  st.set_page_config(page_title="Multi-Domain Zero Shot AI", layout="wide")
16
 
17
  st.title("Multi-Domain Zero Shot Image Classification")
 
27
  @st.cache_resource
28
  def load_models():
29
 
30
+ # -------- BIOMED CLIP --------
31
+ biomed_model, _, biomed_preprocess = open_clip.create_model_and_transforms(
32
+ "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
33
+ )
34
 
35
+ biomed_tokenizer = open_clip.get_tokenizer(
36
+ "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
37
  )
38
 
39
+ biomed_model = biomed_model.to(device).eval()
40
+
41
  # -------- REMOTE CLIP --------
42
  remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
43
  "ViT-B-32",
 
68
 
69
  return (
70
  biomed_model,
71
+ biomed_preprocess,
72
+ biomed_tokenizer,
73
  remote_model,
74
  remote_preprocess,
75
  remote_tokenizer,
 
82
 
83
  (
84
  biomed_model,
85
+ biomed_preprocess,
86
+ biomed_tokenizer,
87
  remote_model,
88
  remote_preprocess,
89
  remote_tokenizer,
 
179
 
180
 
181
  # --------------------------------------------------
182
+ # MEDICAL + SKIN (BIOMEDCLIP)
183
  # --------------------------------------------------
184
 
185
  if dataset_key in ["medical", "skin_disease"]:
186
 
187
+ img = biomed_preprocess(image).unsqueeze(0).to(device)
188
+ text = biomed_tokenizer(text_queries)
 
 
 
 
189
 
190
  with torch.no_grad():
191
+ image_features = biomed_model.encode_image(img)
192
+ text_features = biomed_model.encode_text(text)
193
 
194
+ similarity = (image_features @ text_features.T).softmax(dim=-1)
195
 
196
 
197
  # --------------------------------------------------