GSMK committed on
Commit
00fa93d
·
verified ·
1 Parent(s): 58150c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -21
app.py CHANGED
@@ -4,16 +4,16 @@ from PIL import Image
4
  import open_clip
5
 
6
  from transformers import (
 
 
7
  BlipProcessor,
8
- BlipForConditionalGeneration,
9
- AutoProcessor,
10
- AutoModel
11
  )
12
 
13
- st.set_page_config(page_title="Zero Shot Image Classification", layout="wide")
14
 
15
- st.title("Zero Shot Image Classification")
16
- st.write("BiomedCLIP + RemoteCLIP + AgriCLIP + BLIP")
17
 
18
  device = "cpu"
19
 
@@ -25,7 +25,7 @@ device = "cpu"
25
  @st.cache_resource
26
  def load_models():
27
 
28
- # -------- BIOMEDCLIP --------
29
  biomed_model, _, biomed_preprocess = open_clip.create_model_and_transforms(
30
  "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
31
  )
@@ -36,7 +36,8 @@ def load_models():
36
 
37
  biomed_model = biomed_model.to(device).eval()
38
 
39
- # -------- REMOTECLIP --------
 
40
  remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
41
  "ViT-B-32",
42
  pretrained="laion2b_s34b_b79k"
@@ -45,11 +46,18 @@ def load_models():
45
  remote_tokenizer = open_clip.get_tokenizer("ViT-B-32")
46
  remote_model = remote_model.to(device).eval()
47
 
48
- # -------- AGRICLIP --------
49
- agri_processor = AutoProcessor.from_pretrained("hmellor/agriclip")
50
- agri_model = AutoModel.from_pretrained("hmellor/agriclip").to(device).eval()
51
 
52
- # -------- BLIP --------
 
 
 
 
 
 
 
 
 
 
53
  blip_processor = BlipProcessor.from_pretrained(
54
  "Salesforce/blip-image-captioning-base"
55
  )
@@ -58,6 +66,7 @@ def load_models():
58
  "Salesforce/blip-image-captioning-base"
59
  ).to(device).eval()
60
 
 
61
  return (
62
  biomed_model,
63
  biomed_preprocess,
@@ -65,8 +74,8 @@ def load_models():
65
  remote_model,
66
  remote_preprocess,
67
  remote_tokenizer,
68
- agri_model,
69
- agri_processor,
70
  blip_processor,
71
  blip_model
72
  )
@@ -79,15 +88,15 @@ def load_models():
79
  remote_model,
80
  remote_preprocess,
81
  remote_tokenizer,
82
- agri_model,
83
- agri_processor,
84
  blip_processor,
85
  blip_model
86
  ) = load_models()
87
 
88
 
89
  # --------------------------------------------------
90
- # DATASETS
91
  # --------------------------------------------------
92
 
93
  DATASETS = {
@@ -99,7 +108,7 @@ DATASETS = {
99
 
100
 
101
  # --------------------------------------------------
102
- # PROMPT TEMPLATES WITH SUBCLASSES
103
  # --------------------------------------------------
104
 
105
  templates = {
@@ -291,6 +300,7 @@ if uploaded_file:
291
 
292
  similarity = (image_features @ text_features.T).softmax(dim=-1)
293
 
 
294
  elif dataset_key == "satellite":
295
 
296
  img = remote_preprocess(image).unsqueeze(0).to(device)
@@ -302,9 +312,10 @@ if uploaded_file:
302
 
303
  similarity = (image_features @ text_features.T).softmax(dim=-1)
304
 
 
305
  elif dataset_key == "agriculture":
306
 
307
- inputs = agri_processor(
308
  text=text_queries,
309
  images=image,
310
  return_tensors="pt",
@@ -312,7 +323,7 @@ if uploaded_file:
312
  ).to(device)
313
 
314
  with torch.no_grad():
315
- outputs = agri_model(**inputs)
316
 
317
  similarity = outputs.logits_per_image.softmax(dim=1)
318
 
@@ -326,7 +337,7 @@ if uploaded_file:
326
 
327
 
328
  # --------------------------------------------------
329
- # BLIP CAPTION
330
  # --------------------------------------------------
331
 
332
  blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
 
4
  import open_clip
5
 
6
  from transformers import (
7
+ CLIPProcessor,
8
+ CLIPModel,
9
  BlipProcessor,
10
+ BlipForConditionalGeneration
 
 
11
  )
12
 
13
+ st.set_page_config(page_title="Multi-Domain Zero Shot AI", layout="wide")
14
 
15
+ st.title("Multi-Domain Zero Shot Image Classification")
16
+ st.write("BiomedCLIP + RemoteCLIP + CLIP + BLIP")
17
 
18
  device = "cpu"
19
 
 
25
  @st.cache_resource
26
  def load_models():
27
 
28
+ # ---------- BIOMEDCLIP ----------
29
  biomed_model, _, biomed_preprocess = open_clip.create_model_and_transforms(
30
  "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
31
  )
 
36
 
37
  biomed_model = biomed_model.to(device).eval()
38
 
39
+
40
+ # ---------- REMOTECLIP ----------
41
  remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
42
  "ViT-B-32",
43
  pretrained="laion2b_s34b_b79k"
 
46
  remote_tokenizer = open_clip.get_tokenizer("ViT-B-32")
47
  remote_model = remote_model.to(device).eval()
48
 
 
 
 
49
 
50
+ # ---------- CLIP (AGRICULTURE) ----------
51
+ clip_model = CLIPModel.from_pretrained(
52
+ "openai/clip-vit-base-patch32"
53
+ ).to(device).eval()
54
+
55
+ clip_processor = CLIPProcessor.from_pretrained(
56
+ "openai/clip-vit-base-patch32"
57
+ )
58
+
59
+
60
+ # ---------- BLIP ----------
61
  blip_processor = BlipProcessor.from_pretrained(
62
  "Salesforce/blip-image-captioning-base"
63
  )
 
66
  "Salesforce/blip-image-captioning-base"
67
  ).to(device).eval()
68
 
69
+
70
  return (
71
  biomed_model,
72
  biomed_preprocess,
 
74
  remote_model,
75
  remote_preprocess,
76
  remote_tokenizer,
77
+ clip_model,
78
+ clip_processor,
79
  blip_processor,
80
  blip_model
81
  )
 
88
  remote_model,
89
  remote_preprocess,
90
  remote_tokenizer,
91
+ clip_model,
92
+ clip_processor,
93
  blip_processor,
94
  blip_model
95
  ) = load_models()
96
 
97
 
98
  # --------------------------------------------------
99
+ # DATASET CLASSES
100
  # --------------------------------------------------
101
 
102
  DATASETS = {
 
108
 
109
 
110
  # --------------------------------------------------
111
+ # PROMPT TEMPLATES (SUBCLASS PROMPTS)
112
  # --------------------------------------------------
113
 
114
  templates = {
 
300
 
301
  similarity = (image_features @ text_features.T).softmax(dim=-1)
302
 
303
+
304
  elif dataset_key == "satellite":
305
 
306
  img = remote_preprocess(image).unsqueeze(0).to(device)
 
312
 
313
  similarity = (image_features @ text_features.T).softmax(dim=-1)
314
 
315
+
316
  elif dataset_key == "agriculture":
317
 
318
+ inputs = clip_processor(
319
  text=text_queries,
320
  images=image,
321
  return_tensors="pt",
 
323
  ).to(device)
324
 
325
  with torch.no_grad():
326
+ outputs = clip_model(**inputs)
327
 
328
  similarity = outputs.logits_per_image.softmax(dim=1)
329
 
 
337
 
338
 
339
  # --------------------------------------------------
340
+ # BLIP IMAGE CAPTION
341
  # --------------------------------------------------
342
 
343
  blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)