prithivMLmods committed
Commit fe17f1e · verified · 1 Parent(s): 30007b5

update app

Files changed (1):
  1. app.py +60 -53
app.py CHANGED
@@ -28,7 +28,7 @@ class OrangeRedTheme(Soft):
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.orange_red, # Use the new color
+        secondary_hue: colors.Color | str = colors.orange_red,
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -55,80 +55,87 @@ class OrangeRedTheme(Soft):
             button_primary_text_color_hover="white",
             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
-            button_secondary_text_color="black",
-            button_secondary_text_color_hover="white",
-            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
-            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
-            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
-            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
-            slider_color="*secondary_500",
-            slider_color_dark="*secondary_600",
             block_title_text_weight="600",
-            block_border_width="3px",
             block_shadow="*shadow_drop_lg",
-            button_primary_shadow="*shadow_drop_lg",
-            button_large_padding="11px",
-            color_accent_soft="*primary_100",
-            block_label_background_fill="*primary_200",
         )
 
 orange_red_theme = OrangeRedTheme()
 
-model = AutoModel.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16", torch_dtype=torch.bfloat16, attn_implementation="sdpa")
+model = AutoModel.from_pretrained(
+    "facebook/metaclip-2-mt5-worldwide-s16",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="sdpa"
+)
 processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")
 
 def postprocess_metaclip(probs, labels):
-    output = {labels[i]: probs[0][i].item() for i in range(len(labels))}
-    return output
-
+    return {labels[i]: probs[0][i].item() for i in range(len(labels))}
 
 def metaclip_detector(image, texts):
     inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
-    logits_per_image = outputs.logits_per_image
-    probs = logits_per_image.softmax(dim=1)
+    probs = outputs.logits_per_image.softmax(dim=1)
     return probs
 
-
 def infer(image, candidate_labels):
-    candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
+    candidate_labels = [l.strip() for l in candidate_labels.split(",")]
     probs = metaclip_detector(image, candidate_labels)
     return postprocess_metaclip(probs, labels=candidate_labels)
 
-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 960px;
+css = """
+#root, body, html {
+    margin: 0;
+    padding: 0;
+    height: 100%;
+}
+
+.center-container {
+    max-width: 900px;
+    margin: 0 auto !important;
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+}
+
+#main-title h1 {
+    text-align: center !important;
+    width: 100%;
 }
-#main-title h1 {font-size: 2.1em !important;}
 """
 
 with gr.Blocks(css=css, theme=orange_red_theme) as demo:
-    gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**")
-    gr.Markdown(
-        "The demo of MetaCLIP 2 for zero-shot classification in this Space."
-    )
-    with gr.Row():
-        with gr.Column():
-            image_input = gr.Image(type="pil")
-            text_input = gr.Textbox(label="Input a list of labels (comma seperated)")
-            run_button = gr.Button("Run", variant="primary")
-        with gr.Column():
-            metaclip_output = gr.Label(label="MetaCLIP 2 Output", num_top_classes=3)
-
-    examples = [
-        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon"],
-        ["./cat.jpg", "a cat, two cats, three cats"],
-        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
-    ]
-    gr.Examples(
-        examples=examples,
-        inputs=[image_input, text_input],
-        outputs=[metaclip_output],
-        fn=infer,
-    )
-    run_button.click(fn=infer, inputs=[image_input, text_input], outputs=[metaclip_output])
+    with gr.Column(elem_classes="center-container"):
+
+        gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**", elem_id="main-title")
+        gr.Markdown("This is the demo of MetaCLIP 2 for zero-shot classification.")
+
+        with gr.Row():
+            with gr.Column():
+                image_input = gr.Image(type="pil", label="Upload Image")
+                text_input = gr.Textbox(label="Input labels (comma separated)")
+                run_button = gr.Button("Run", variant="primary")
+            with gr.Column():
+                metaclip_output = gr.Label(
+                    label="MetaCLIP 2 Output",
+                    num_top_classes=3
+                )
+
+        gr.Examples(
+            examples=[
+                ["./baklava.jpg", "dessert on a plate, baklava"],
+                ["./cat.jpg", "a cat, two cats, three cats"],
+                ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down"],
+            ],
+            inputs=[image_input, text_input],
+            outputs=[metaclip_output],
+            fn=infer,
+        )
+
+        run_button.click(
+            fn=infer,
+            inputs=[image_input, text_input],
+            outputs=[metaclip_output]
+        )
+
 demo.launch()
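
For a quick sanity check of the updated inference path, the model calls in this commit can be exercised without launching the Gradio UI. This is a minimal sketch, assuming the same transformers APIs and checkpoint used in app.py above; the ./cat.jpg path is assumed to point at the example image the Space already ships.

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Same checkpoint and load arguments as in app.py.
model = AutoModel.from_pretrained(
    "facebook/metaclip-2-mt5-worldwide-s16",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("facebook/metaclip-2-mt5-worldwide-s16")

labels = ["a cat", "two cats", "three cats"]  # candidate labels, as in the Examples
image = Image.open("./cat.jpg")               # assumed local example image

# Tokenize the labels and preprocess the image in one call, then score image-text pairs.
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_image -> probabilities over the candidate labels, as in metaclip_detector().
probs = outputs.logits_per_image.softmax(dim=1)
print({label: prob.item() for label, prob in zip(labels, probs[0])})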