Update app.py
Browse files
app.py
CHANGED
|
@@ -4,16 +4,16 @@ from PIL import Image
|
|
| 4 |
import open_clip
|
| 5 |
|
| 6 |
from transformers import (
|
|
|
|
|
|
|
| 7 |
BlipProcessor,
|
| 8 |
-
BlipForConditionalGeneration
|
| 9 |
-
AutoProcessor,
|
| 10 |
-
AutoModel
|
| 11 |
)
|
| 12 |
|
| 13 |
-
st.set_page_config(page_title="Zero Shot
|
| 14 |
|
| 15 |
-
st.title("Zero Shot Image Classification")
|
| 16 |
-
st.write("BiomedCLIP + RemoteCLIP +
|
| 17 |
|
| 18 |
device = "cpu"
|
| 19 |
|
|
@@ -25,7 +25,7 @@ device = "cpu"
|
|
| 25 |
@st.cache_resource
|
| 26 |
def load_models():
|
| 27 |
|
| 28 |
-
# -------- BIOMEDCLIP --------
|
| 29 |
biomed_model, _, biomed_preprocess = open_clip.create_model_and_transforms(
|
| 30 |
"hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
|
| 31 |
)
|
|
@@ -36,7 +36,8 @@ def load_models():
|
|
| 36 |
|
| 37 |
biomed_model = biomed_model.to(device).eval()
|
| 38 |
|
| 39 |
-
|
|
|
|
| 40 |
remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
|
| 41 |
"ViT-B-32",
|
| 42 |
pretrained="laion2b_s34b_b79k"
|
|
@@ -45,11 +46,18 @@ def load_models():
|
|
| 45 |
remote_tokenizer = open_clip.get_tokenizer("ViT-B-32")
|
| 46 |
remote_model = remote_model.to(device).eval()
|
| 47 |
|
| 48 |
-
# -------- AGRICLIP --------
|
| 49 |
-
agri_processor = AutoProcessor.from_pretrained("hmellor/agriclip")
|
| 50 |
-
agri_model = AutoModel.from_pretrained("hmellor/agriclip").to(device).eval()
|
| 51 |
|
| 52 |
-
# --------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
blip_processor = BlipProcessor.from_pretrained(
|
| 54 |
"Salesforce/blip-image-captioning-base"
|
| 55 |
)
|
|
@@ -58,6 +66,7 @@ def load_models():
|
|
| 58 |
"Salesforce/blip-image-captioning-base"
|
| 59 |
).to(device).eval()
|
| 60 |
|
|
|
|
| 61 |
return (
|
| 62 |
biomed_model,
|
| 63 |
biomed_preprocess,
|
|
@@ -65,8 +74,8 @@ def load_models():
|
|
| 65 |
remote_model,
|
| 66 |
remote_preprocess,
|
| 67 |
remote_tokenizer,
|
| 68 |
-
|
| 69 |
-
|
| 70 |
blip_processor,
|
| 71 |
blip_model
|
| 72 |
)
|
|
@@ -79,15 +88,15 @@ def load_models():
|
|
| 79 |
remote_model,
|
| 80 |
remote_preprocess,
|
| 81 |
remote_tokenizer,
|
| 82 |
-
|
| 83 |
-
|
| 84 |
blip_processor,
|
| 85 |
blip_model
|
| 86 |
) = load_models()
|
| 87 |
|
| 88 |
|
| 89 |
# --------------------------------------------------
|
| 90 |
-
#
|
| 91 |
# --------------------------------------------------
|
| 92 |
|
| 93 |
DATASETS = {
|
|
@@ -99,7 +108,7 @@ DATASETS = {
|
|
| 99 |
|
| 100 |
|
| 101 |
# --------------------------------------------------
|
| 102 |
-
# PROMPT TEMPLATES
|
| 103 |
# --------------------------------------------------
|
| 104 |
|
| 105 |
templates = {
|
|
@@ -291,6 +300,7 @@ if uploaded_file:
|
|
| 291 |
|
| 292 |
similarity = (image_features @ text_features.T).softmax(dim=-1)
|
| 293 |
|
|
|
|
| 294 |
elif dataset_key == "satellite":
|
| 295 |
|
| 296 |
img = remote_preprocess(image).unsqueeze(0).to(device)
|
|
@@ -302,9 +312,10 @@ if uploaded_file:
|
|
| 302 |
|
| 303 |
similarity = (image_features @ text_features.T).softmax(dim=-1)
|
| 304 |
|
|
|
|
| 305 |
elif dataset_key == "agriculture":
|
| 306 |
|
| 307 |
-
inputs =
|
| 308 |
text=text_queries,
|
| 309 |
images=image,
|
| 310 |
return_tensors="pt",
|
|
@@ -312,7 +323,7 @@ if uploaded_file:
|
|
| 312 |
).to(device)
|
| 313 |
|
| 314 |
with torch.no_grad():
|
| 315 |
-
outputs =
|
| 316 |
|
| 317 |
similarity = outputs.logits_per_image.softmax(dim=1)
|
| 318 |
|
|
@@ -326,7 +337,7 @@ if uploaded_file:
|
|
| 326 |
|
| 327 |
|
| 328 |
# --------------------------------------------------
|
| 329 |
-
# BLIP CAPTION
|
| 330 |
# --------------------------------------------------
|
| 331 |
|
| 332 |
blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
|
|
|
|
| 4 |
import open_clip
|
| 5 |
|
| 6 |
from transformers import (
|
| 7 |
+
CLIPProcessor,
|
| 8 |
+
CLIPModel,
|
| 9 |
BlipProcessor,
|
| 10 |
+
BlipForConditionalGeneration
|
|
|
|
|
|
|
| 11 |
)
|
| 12 |
|
| 13 |
+
st.set_page_config(page_title="Multi-Domain Zero Shot AI", layout="wide")
|
| 14 |
|
| 15 |
+
st.title("Multi-Domain Zero Shot Image Classification")
|
| 16 |
+
st.write("BiomedCLIP + RemoteCLIP + CLIP + BLIP")
|
| 17 |
|
| 18 |
device = "cpu"
|
| 19 |
|
|
|
|
| 25 |
@st.cache_resource
|
| 26 |
def load_models():
|
| 27 |
|
| 28 |
+
# ---------- BIOMEDCLIP ----------
|
| 29 |
biomed_model, _, biomed_preprocess = open_clip.create_model_and_transforms(
|
| 30 |
"hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
|
| 31 |
)
|
|
|
|
| 36 |
|
| 37 |
biomed_model = biomed_model.to(device).eval()
|
| 38 |
|
| 39 |
+
|
| 40 |
+
# ---------- REMOTECLIP ----------
|
| 41 |
remote_model, _, remote_preprocess = open_clip.create_model_and_transforms(
|
| 42 |
"ViT-B-32",
|
| 43 |
pretrained="laion2b_s34b_b79k"
|
|
|
|
| 46 |
remote_tokenizer = open_clip.get_tokenizer("ViT-B-32")
|
| 47 |
remote_model = remote_model.to(device).eval()
|
| 48 |
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
# ---------- CLIP (AGRICULTURE) ----------
|
| 51 |
+
clip_model = CLIPModel.from_pretrained(
|
| 52 |
+
"openai/clip-vit-base-patch32"
|
| 53 |
+
).to(device).eval()
|
| 54 |
+
|
| 55 |
+
clip_processor = CLIPProcessor.from_pretrained(
|
| 56 |
+
"openai/clip-vit-base-patch32"
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------- BLIP ----------
|
| 61 |
blip_processor = BlipProcessor.from_pretrained(
|
| 62 |
"Salesforce/blip-image-captioning-base"
|
| 63 |
)
|
|
|
|
| 66 |
"Salesforce/blip-image-captioning-base"
|
| 67 |
).to(device).eval()
|
| 68 |
|
| 69 |
+
|
| 70 |
return (
|
| 71 |
biomed_model,
|
| 72 |
biomed_preprocess,
|
|
|
|
| 74 |
remote_model,
|
| 75 |
remote_preprocess,
|
| 76 |
remote_tokenizer,
|
| 77 |
+
clip_model,
|
| 78 |
+
clip_processor,
|
| 79 |
blip_processor,
|
| 80 |
blip_model
|
| 81 |
)
|
|
|
|
| 88 |
remote_model,
|
| 89 |
remote_preprocess,
|
| 90 |
remote_tokenizer,
|
| 91 |
+
clip_model,
|
| 92 |
+
clip_processor,
|
| 93 |
blip_processor,
|
| 94 |
blip_model
|
| 95 |
) = load_models()
|
| 96 |
|
| 97 |
|
| 98 |
# --------------------------------------------------
|
| 99 |
+
# DATASET CLASSES
|
| 100 |
# --------------------------------------------------
|
| 101 |
|
| 102 |
DATASETS = {
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
# --------------------------------------------------
|
| 111 |
+
# PROMPT TEMPLATES (SUBCLASS PROMPTS)
|
| 112 |
# --------------------------------------------------
|
| 113 |
|
| 114 |
templates = {
|
|
|
|
| 300 |
|
| 301 |
similarity = (image_features @ text_features.T).softmax(dim=-1)
|
| 302 |
|
| 303 |
+
|
| 304 |
elif dataset_key == "satellite":
|
| 305 |
|
| 306 |
img = remote_preprocess(image).unsqueeze(0).to(device)
|
|
|
|
| 312 |
|
| 313 |
similarity = (image_features @ text_features.T).softmax(dim=-1)
|
| 314 |
|
| 315 |
+
|
| 316 |
elif dataset_key == "agriculture":
|
| 317 |
|
| 318 |
+
inputs = clip_processor(
|
| 319 |
text=text_queries,
|
| 320 |
images=image,
|
| 321 |
return_tensors="pt",
|
|
|
|
| 323 |
).to(device)
|
| 324 |
|
| 325 |
with torch.no_grad():
|
| 326 |
+
outputs = clip_model(**inputs)
|
| 327 |
|
| 328 |
similarity = outputs.logits_per_image.softmax(dim=1)
|
| 329 |
|
|
|
|
| 337 |
|
| 338 |
|
| 339 |
# --------------------------------------------------
|
| 340 |
+
# BLIP IMAGE CAPTION
|
| 341 |
# --------------------------------------------------
|
| 342 |
|
| 343 |
blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
|