Karthika0308 commited on
Commit
e5b9c2c
·
verified ·
1 Parent(s): 1035f85

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -77
app.py CHANGED
@@ -1,42 +1,60 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3
  import torch
4
  import gradio as gr
5
  from gradio_image_prompter import ImagePrompter
6
- from torch.nn import DataParallel
7
  from models.counter_infer import build_model
8
  from utils.arg_parser import get_argparser
9
  from utils.data import resize_and_pad
10
  import torchvision.ops as ops
11
  from torchvision import transforms as T
12
- from PIL import Image, ImageDraw, ImageFont
13
  import numpy as np
14
 
15
- # Load model (once, to avoid reloading)
 
 
16
  def load_model():
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
- args = get_argparser().parse_args()
 
19
  args.zero_shot = True
20
- model = DataParallel(build_model(args).to(device))
21
- model.load_state_dict(torch.load('CNTQG_multitrain_ca44.pth', weights_only=True)['model'], strict=False)
 
 
 
 
 
22
  model.eval()
 
23
  return model, device
24
 
25
  model, device = load_model()
26
 
27
- # **Function to Process Image Once**
 
 
28
  def process_image_once(inputs, enable_mask):
29
- model.module.return_masks = enable_mask
30
 
31
  image = inputs['image']
32
  drawn_boxes = inputs['points']
 
33
  image_tensor = torch.tensor(image).to(device)
34
  image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
35
- image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
36
 
37
- bboxes_tensor = torch.tensor([[box[0], box[1], box[3], box[4]] for box in drawn_boxes], dtype=torch.float32).to(device)
 
 
 
 
 
 
 
 
38
 
39
  img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
 
40
  img = img.unsqueeze(0).to(device)
41
  bboxes = bboxes.unsqueeze(0).to(device)
42
 
@@ -45,69 +63,44 @@ def process_image_once(inputs, enable_mask):
45
 
46
  return image, outputs, masks, img, scale, drawn_boxes
47
 
48
- # **Post-process and Update Output**
 
 
49
  def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
50
  idx = 0
51
- threshold = 1/threshold
52
- keep = ops.nms(outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold],
53
- outputs[idx]['box_v'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold], 0.5)
 
54
 
55
- pred_boxes = outputs[idx]['pred_boxes'][outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold][keep]
 
 
 
 
 
 
 
 
56
  pred_boxes = torch.clamp(pred_boxes, 0, 1)
57
 
58
  pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()
59
 
60
- image = Image.fromarray((image).astype(np.uint8))
61
-
62
- if enable_mask:
63
- from matplotlib import pyplot as plt
64
- masks_ = masks[idx][(outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold)[0]]
65
- N_masks = masks_.shape[0]
66
- indices = torch.randint(1, N_masks + 1, (1, N_masks), device=masks_.device).view(-1, 1, 1)
67
- masks = (masks_ * indices).sum(dim=0)
68
- mask_display = (
69
- T.Resize((int(img.shape[2] / scale), int(img.shape[3] / scale)), interpolation=T.InterpolationMode.NEAREST)(
70
- masks.cpu().unsqueeze(0))[0])[:image.size[1], :image.size[0]]
71
- cmap = plt.cm.tab20
72
- norm = plt.Normalize(vmin=0, vmax=N_masks)
73
- del masks
74
- del masks_
75
- del outputs
76
- rgba_image = cmap(norm(mask_display))
77
- rgba_image[mask_display == 0, -1] = 0
78
- rgba_image[mask_display != 0, -1] = 0.5
79
-
80
- overlay = Image.fromarray((rgba_image * 255).astype(np.uint8), mode="RGBA")
81
- image = image.convert("RGBA")
82
- image = Image.alpha_composite(image, overlay)
83
-
84
 
85
  draw = ImageDraw.Draw(image)
 
86
  for box in pred_boxes:
87
  draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
88
- # for box in drawn_boxes:
89
- # draw.rectangle([box[0], box[1], box[3], box[4]], outline="red", width=3)
90
-
91
- width, height = image.size
92
- square_size = int(0.05 * width)
93
- x1, y1 = 10, height - square_size - 10
94
- x2, y2 = x1 + square_size, y1 + square_size
95
-
96
- # draw.rectangle([x1, y1, x2, y2], outline="black", fill="black", width=1)
97
- # font = ImageFont.load_default()
98
- # txt = str(len(pred_boxes))
99
- # w = draw.textlength(txt, font=font)
100
- # text_x = x1 + (square_size - w) / 2
101
- # text_y = y1 + (square_size - 10) / 2
102
- # draw.text((text_x, text_y), txt, fill="white", font=font)
103
 
104
  return image, len(pred_boxes)
105
 
106
-
 
 
107
  iface = gr.Blocks()
108
 
109
  with iface:
110
- # Store intermediate states
111
  image_input = gr.State()
112
  outputs_state = gr.State()
113
  masks_state = gr.State()
@@ -115,45 +108,34 @@ with iface:
115
  scale_state = gr.State()
116
  drawn_boxes_state = gr.State()
117
 
118
- # UI Layout: Input Section
119
  with gr.Row():
120
  image_prompter = ImagePrompter()
121
  image_output = gr.Image(type="pil")
122
-
123
 
124
- # UI Layout: Output Section
125
  with gr.Row():
126
  count_output = gr.Number(label="Total Count")
127
- enable_mask = gr.Checkbox(label="Predict masks", value=True) # Mask enabled by default
128
- threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01, label="Threshold") # Updated range and default
129
-
130
 
131
- # Create the 'Count' button
132
  count_button = gr.Button("Count")
133
 
134
- # Process image once when "Count" button is pressed
135
  def initial_process(inputs, enable_mask, threshold):
136
- # Perform inference once
137
  image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
138
 
139
- # Save intermediate states
140
  return (
141
- *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold), # Processed outputs
142
- image, outputs, masks, img, scale, drawn_boxes # Store in states for later use
143
  )
144
 
145
- # Update image and count when the threshold slider changes (post-process only)
146
  def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
147
  return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
148
 
149
- # Run initial inference and post-process when "Count" button is clicked
150
  count_button.click(
151
  initial_process,
152
- [image_prompter, enable_mask, threshold], # Inputs
153
- [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state] # Outputs + States
154
  )
155
 
156
- # Adjust the output dynamically based on the threshold slider (no re-inference)
157
  threshold.change(
158
  update_threshold,
159
  [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
@@ -161,10 +143,9 @@ with iface:
161
  )
162
 
163
  enable_mask.change(
164
- update_threshold,
165
  [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
166
  [image_output, count_output]
167
  )
168
 
169
- iface.launch(share=True)
170
-
 
1
  import os
 
2
  import torch
3
  import gradio as gr
4
  from gradio_image_prompter import ImagePrompter
 
5
  from models.counter_infer import build_model
6
  from utils.arg_parser import get_argparser
7
  from utils.data import resize_and_pad
8
  import torchvision.ops as ops
9
  from torchvision import transforms as T
10
+ from PIL import Image, ImageDraw
11
  import numpy as np
12
 
13
+ # -----------------------
14
+ # LOAD MODEL
15
+ # -----------------------
16
  def load_model():
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
+ args = get_argparser().parse_args([])
20
  args.zero_shot = True
21
+
22
+ model = build_model(args).to(device)
23
+
24
+ # ⚠️ Make sure this file exists in your repo
25
+ checkpoint = torch.load("CNTQG_multitrain_ca44.pth", map_location=device)
26
+
27
+ model.load_state_dict(checkpoint["model"], strict=False)
28
  model.eval()
29
+
30
  return model, device
31
 
32
  model, device = load_model()
33
 
34
+ # -----------------------
35
+ # PROCESS IMAGE
36
+ # -----------------------
37
  def process_image_once(inputs, enable_mask):
38
+ model.return_masks = enable_mask # ✅ FIXED
39
 
40
  image = inputs['image']
41
  drawn_boxes = inputs['points']
42
+
43
  image_tensor = torch.tensor(image).to(device)
44
  image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
 
45
 
46
+ image_tensor = T.Normalize(
47
+ mean=[0.485, 0.456, 0.406],
48
+ std=[0.229, 0.224, 0.225]
49
+ )(image_tensor)
50
+
51
+ bboxes_tensor = torch.tensor(
52
+ [[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
53
+ dtype=torch.float32
54
+ ).to(device)
55
 
56
  img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
57
+
58
  img = img.unsqueeze(0).to(device)
59
  bboxes = bboxes.unsqueeze(0).to(device)
60
 
 
63
 
64
  return image, outputs, masks, img, scale, drawn_boxes
65
 
66
+ # -----------------------
67
+ # POST PROCESS
68
+ # -----------------------
69
  def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
70
  idx = 0
71
+ threshold = 1 / threshold
72
+
73
+ scores = outputs[idx]['box_v']
74
+ boxes = outputs[idx]['pred_boxes']
75
 
76
+ keep_mask = scores > scores.max() / threshold
77
+
78
+ keep = ops.nms(
79
+ boxes[keep_mask],
80
+ scores[keep_mask],
81
+ 0.5
82
+ )
83
+
84
+ pred_boxes = boxes[keep_mask][keep]
85
  pred_boxes = torch.clamp(pred_boxes, 0, 1)
86
 
87
  pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()
88
 
89
+ image = Image.fromarray(image.astype(np.uint8))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  draw = ImageDraw.Draw(image)
92
+
93
  for box in pred_boxes:
94
  draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  return image, len(pred_boxes)
97
 
98
+ # -----------------------
99
+ # GRADIO UI
100
+ # -----------------------
101
  iface = gr.Blocks()
102
 
103
  with iface:
 
104
  image_input = gr.State()
105
  outputs_state = gr.State()
106
  masks_state = gr.State()
 
108
  scale_state = gr.State()
109
  drawn_boxes_state = gr.State()
110
 
 
111
  with gr.Row():
112
  image_prompter = ImagePrompter()
113
  image_output = gr.Image(type="pil")
 
114
 
 
115
  with gr.Row():
116
  count_output = gr.Number(label="Total Count")
117
+ enable_mask = gr.Checkbox(label="Predict masks", value=True)
118
+ threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01)
 
119
 
 
120
  count_button = gr.Button("Count")
121
 
 
122
  def initial_process(inputs, enable_mask, threshold):
 
123
  image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
124
 
 
125
  return (
126
+ *post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
127
+ image, outputs, masks, img, scale, drawn_boxes
128
  )
129
 
 
130
  def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
131
  return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
132
 
 
133
  count_button.click(
134
  initial_process,
135
+ [image_prompter, enable_mask, threshold],
136
+ [image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state]
137
  )
138
 
 
139
  threshold.change(
140
  update_threshold,
141
  [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
 
143
  )
144
 
145
  enable_mask.change(
146
+ update_threshold,
147
  [threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
148
  [image_output, count_output]
149
  )
150
 
151
+ iface.launch()