prithivMLmods committed on
Commit
74593d4
·
verified ·
1 Parent(s): e867be7

upload app

Browse files
Files changed (1) hide show
  1. app.py +215 -0
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ import random
6
+ from PIL import Image, ImageDraw
7
+ from typing import Iterable
8
+ from gradio.themes import Soft
9
+ from gradio.themes.utils import colors, fonts, sizes
10
+ from transformers import Sam3Processor, Sam3Model
11
+
12
# --- Handle optional 'spaces' import for local compatibility ---
try:
    import spaces
except ImportError:
    class spaces:
        """No-op stand-in used when the optional `spaces` package is missing.

        It exposes a `GPU` decorator that leaves the decorated function
        untouched, so code written with `@spaces.GPU` still runs locally.
        """

        @staticmethod
        def GPU(func=None, duration=60):
            """Return the decorated function unchanged.

            Works both as a bare decorator (`@spaces.GPU`) and as a
            decorator factory (`@spaces.GPU()`, `@spaces.GPU(duration=120)`,
            `@spaces.GPU(120)`).

            Args:
                func: The function being decorated when used bare. Any
                    non-callable value (e.g. a positional duration) is
                    ignored.
                duration: Accepted for signature compatibility; unused by
                    this fallback.

            Returns:
                `func` itself when used bare, otherwise a decorator that
                returns its argument unchanged.
            """
            if callable(func):
                # Bare `@spaces.GPU`: we were handed the function directly.
                return func

            def decorator(wrapped):
                return wrapped

            return decorator
22
+
23
# --- Custom Theme Setup (Plum) ---
# Register a "plum" hue on gradio's colors module so the theme below can
# reference it as `colors.plum`. Shades run from lightest (c50) to
# darkest (c950).
colors.plum = colors.Color(
    **{
        "name": "plum",
        "c50": "#FDF4FD",
        "c100": "#F7E6F7",
        "c200": "#ECD0EC",
        "c300": "#DDA0DD",  # Plum
        "c400": "#C98BC9",
        "c500": "#B060B0",
        "c600": "#964B96",
        "c700": "#7A3A7A",
        "c800": "#602C60",
        "c900": "#451E45",
        "c950": "#2B122B",
    }
)
38
+
39
class PlumTheme(Soft):
    """Soft-based Gradio theme that uses the plum palette for its accents."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.plum,
        secondary_hue: colors.Color | str = colors.plum,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        """Build the base Soft theme, then apply the plum style overrides.

        All arguments are keyword-only and are forwarded unchanged to
        `Soft.__init__`.
        """
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Page and block backgrounds.
        overrides = {
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_100, *primary_50)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
        }

        # Primary buttons.
        overrides.update(
            {
                "button_primary_text_color": "white",
                "button_primary_text_color_hover": "white",
                "button_primary_background_fill": "linear-gradient(90deg, *primary_500, *primary_600)",
                "button_primary_background_fill_hover": "linear-gradient(90deg, *primary_600, *primary_700)",
                "button_primary_background_fill_dark": "linear-gradient(90deg, *primary_600, *primary_800)",
                "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            }
        )

        # Secondary buttons.
        overrides.update(
            {
                "button_secondary_text_color": "black",
                "button_secondary_text_color_hover": "white",
                "button_secondary_background_fill": "linear-gradient(90deg, *primary_200, *primary_200)",
                "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
                "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
                "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            }
        )

        # Sliders, block chrome, shadows and accents.
        overrides.update(
            {
                "slider_color": "*primary_500",
                "slider_color_dark": "*primary_600",
                "block_title_text_weight": "600",
                "block_border_width": "3px",
                "block_shadow": "*shadow_drop_lg",
                "button_primary_shadow": "*shadow_drop_lg",
                "button_large_padding": "11px",
                "color_accent_soft": "*primary_100",
                "block_label_background_fill": "*primary_200",
            }
        )

        self.set(**overrides)


# Single shared theme instance passed to gr.Blocks below.
plum_theme = PlumTheme()
91
+
92
# --- Hardware Setup ---
# Prefer the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# --- Model Loading ---
# Using the facebook/sam3 model as requested
try:
    print("Loading SAM3 Model and Processor...")
    model = Sam3Model.from_pretrained("facebook/sam3")
    model = model.to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")
    print("Model loaded successfully.")
except Exception as load_error:
    # Keep the app importable even when loading fails; segment_image
    # checks for None and reports the problem to the user instead.
    print(f"Error loading model: {load_error}")
    print("Ensure you have the correct libraries installed and access to the model.")
    model = processor = None
109
+
110
@spaces.GPU(duration=60)
def segment_image(input_image, text_prompt, threshold=0.5):
    """Segment the objects in an image that match a text prompt.

    Args:
        input_image: Image to segment; must provide `.convert("RGB")`
            (the UI supplies a PIL image). `None` is rejected.
        text_prompt: Text describing what to segment. Leading and trailing
            whitespace is ignored; an empty or whitespace-only prompt is
            rejected.
        threshold: Score threshold forwarded to the processor's
            instance-segmentation post-processing.

    Returns:
        A `(image, annotations)` tuple in the format expected by
        `gr.AnnotatedImage`, where `annotations` is a list of
        `(mask, label)` pairs and each label is `"<prompt> (<score>)"`.

    Raises:
        gr.Error: If the image or prompt is missing, or if the model or
            processor failed to load at startup.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    # Strip before validating so a whitespace-only prompt is rejected
    # instead of being sent to the model.
    prompt = text_prompt.strip() if text_prompt else ""
    if not prompt:
        raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")

    if model is None or processor is None:
        raise gr.Error("Model not loaded correctly.")

    # Convert image to RGB
    image_pil = input_image.convert("RGB")

    # Preprocess
    inputs = processor(images=image_pil, text=prompt, return_tensors="pt").to(device)

    # Inference (no gradients needed)
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results; only the first (and only) image's result is used.
    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=threshold,
        mask_threshold=0.5,
        target_sizes=inputs.get("original_sizes").tolist(),
    )[0]

    # Per the original author's note, masks is a boolean tensor [N, H, W].
    masks_np = results["masks"].cpu().numpy()
    scores_np = results["scores"].cpu().numpy()

    # Gradio's AnnotatedImage expects (image, [(mask, label), ...]) and
    # handles the mask coloring itself.
    annotations = [
        (mask, f"{prompt} ({score:.2f})")
        for mask, score in zip(masks_np, scores_np)
    ]

    return (image_pil, annotations)
158
+
159
css = """
#col-container {
    margin: 0 auto;
    max-width: 980px;
}
#main-title h1 {font-size: 2.1em !important;}
"""

# Example rows in the order of the inputs: [image path, prompt, threshold].
example_rows = [
    ["examples/cat.jpg", "cat", 0.5],
    ["examples/car.jpg", "tire", 0.4],
    ["examples/fruit.jpg", "apple", 0.5],
]
# Keep only rows whose image is actually present so that a missing
# 'examples/' folder does not affect building the UI.
available_examples = [row for row in example_rows if os.path.exists(row[0])]

with gr.Blocks(css=css, theme=plum_theme) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            "# **SAM3 Image Segmentation**",
            elem_id="main-title"
        )

        gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")

        with gr.Row():
            # Left Column: Inputs
            # Integer scales 2:3 keep the intended 1:1.5 width ratio;
            # Gradio expects `scale` to be an integer.
            with gr.Column(scale=2):
                input_image = gr.Image(label="Input Image", type="pil", height=350)
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., cat, ear, car wheel...",
                    info="What do you want to segment?"
                )
                threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)

                run_button = gr.Button("Segment", variant="primary")

            # Right Column: Output
            with gr.Column(scale=3):
                # AnnotatedImage creates a nice overlay visualization
                output_image = gr.AnnotatedImage(label="Segmented Output", height=500)

        # Examples (skipped entirely when none of the files exist)
        if available_examples:
            gr.Examples(
                examples=available_examples,
                inputs=[input_image, text_prompt, threshold],
                outputs=[output_image],
                fn=segment_image,
                cache_examples=False,
                label="Examples (Ensure files exist in 'examples/' folder)"
            )

    run_button.click(
        fn=segment_image,
        inputs=[input_image, text_prompt, threshold],
        outputs=[output_image]
    )

if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_error=True)