nikkoyabut commited on
Commit
c71972d
·
verified ·
1 Parent(s): 80898b0

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py

# 🛠️ Setup
# pip install -q gradio torch ftfy regex tqdm git+https://github.com/openai/CLIP.git matplotlib

# 📦 Imports
import gradio as gr
import torch
import clip
from PIL import Image
import numpy as np
from typing import List, Tuple, Union

# 🚀 Load CLIP Model
# Pick CUDA when available; CLIP inference is substantially faster on GPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# Loading once at module import so every prediction reuses the same weights.
model, preprocess = clip.load("ViT-B/32", device=device)
23
+
24
def predict(image: Image.Image, label_text: str) -> List[List[Union[str, float]]]:
    """
    Perform zero-shot classification using the CLIP model.

    Args:
        image (PIL.Image.Image): Input image (None when nothing was uploaded).
        label_text (str): Comma-separated labels to classify against.

    Returns:
        List[List[Union[str, float]]]: One row per label containing
        [label, "NN.NN%" probability string, confidence-bar HTML].
        Empty list when no image or no usable labels are supplied.
    """
    labels: List[str] = [label.strip() for label in label_text.split(",") if label.strip()]
    # Guard before any model work. Use `is None` rather than truthiness:
    # PIL images do not define a meaningful __bool__.
    if image is None or not labels:
        return []

    # Preprocess inputs: single-image batch + tokenized label prompts.
    image_input: torch.Tensor = preprocess(image).unsqueeze(0).to(device)
    text_inputs: torch.Tensor = clip.tokenize(labels).to(device)

    # Run model. The forward pass model(image, text) encodes both modalities
    # internally, so separate encode_image/encode_text calls (whose results
    # were never used in the original) are redundant work and were removed.
    with torch.no_grad():
        logits_per_image, _ = model(image_input, text_inputs)
        probs: np.ndarray = logits_per_image.softmax(dim=-1).cpu().numpy()[0]

    # Create table with bar visualization: width of the bar tracks probability.
    results: List[List[Union[str, float]]] = []
    for label, prob in zip(labels, probs):
        bar_html: str = (
            f'<div style="background-color:#4caf50;width:{prob * 100:.1f}%;height:20px;"></div>'
        )
        results.append([label, f"{prob * 100:.2f}%", bar_html])

    return results
59
+
60
+
61
# 🎨 Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## CLIP Zero-Shot Classifier")

    # Inputs: image upload side-by-side with a free-form label box.
    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload Image")
        labels_box = gr.Textbox(
            lines=2,
            label="Enter comma-separated labels",
            placeholder="e.g., a cat, a dog, a diagram"
        )

    # Clickable example images that populate the image input.
    with gr.Row():
        gr.Examples(
            examples=[
                ["images/boy.jpg"],
                ["images/dog.jpg"],
                ["images/boy_dog.jpg"]
            ],
            inputs=[img_input],
            label="🖼️ Click to select example image"
        )

    # Clickable example label sets that populate the textbox.
    gr.Examples(
        examples=[
            ["boy, girl, dog, cat"],
            ["a boy with a dog, a boy with a cat, a girl with a dog, a girl with a cat"],
            ["a cat, a dog, a diagram"]
        ],
        inputs=[labels_box],
        label="📝 Click to autofill example labels"
    )

    classify_btn = gr.Button("Classify")

    # Results table: label, formatted probability, and an HTML confidence bar.
    results_table = gr.Dataframe(
        headers=["Label", "Probability", "Confidence Bar"],
        datatype=["str", "str", "html"],
        row_count=5,
        interactive=False
    )

    # Wire the button to the classifier defined above.
    classify_btn.click(fn=predict, inputs=[img_input, labels_box], outputs=results_table)
106
+
107
if __name__ == "__main__":
    # Entry point: start the Gradio server; share=True also exposes a
    # temporary public URL when run outside a hosted environment.
    demo.launch(share=True)