Update tasks/image.py
tasks/image.py  CHANGED  (+22 -4)
@@ -13,7 +13,6 @@ from PIL import Image
 from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation
 import cv2
 from tqdm import tqdm
-from dataset import WildfireSmokeDataset
 from torch.utils.data import DataLoader
 
 from dotenv import load_dotenv
@@ -30,6 +29,19 @@ model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobile
 model.load_state_dict(torch.load(model_path))
 model.eval()
 
+def preprocess(image):
+    image = image.resize((512, 512))
+
+    # Convert to BGR
+    image = np.array(image)[:, :, ::-1]  # Convert RGB to BGR
+    image = Image.fromarray(image)
+    image = image.resize(self.image_size)
+
+    # Normalize pixel values to [0, 1]
+    image = np.array(image, dtype=np.float32) / 255.0
+
+    return image
+
 def get_bounding_boxes_from_mask(mask):
     """Extract bounding boxes from a binary mask."""
     pred_boxes = []
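A note on the new helper: preprocess references self.image_size but is defined as a module-level function, so that second resize would raise a NameError if reached, and it is redundant with the (512, 512) resize on the first line. A minimal standalone sketch of what the helper appears to intend, assuming numpy and PIL are imported as elsewhere in the file (the target_size parameter is an illustration, not part of this commit):

def preprocess(image, target_size=(512, 512)):
    # Resize the PIL image to the model's expected input size
    image = image.resize(target_size)
    # Reorder channels from RGB to BGR
    image = np.array(image)[:, :, ::-1]
    # Normalize pixel values to [0, 1]
    return image.astype(np.float32) / 255.0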
@@ -39,7 +51,7 @@ def get_bounding_boxes_from_mask(mask):
         x, y, w, h = cv2.boundingRect(contour)
         pred_boxes.append((x, y, x + w, y + h))
     return pred_boxes
-
+
 def parse_boxes(annotation_string):
     """Parse multiple boxes from a single annotation string.
     Each box has 5 values: class_id, x_center, y_center, width, height"""
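For context, the contour loop above is the tail of get_bounding_boxes_from_mask; the lines that produce contour are outside this diff, but presumably follow the usual OpenCV pattern, roughly (a sketch, not the committed code):

contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
    # Axis-aligned box around each detected smoke blob, stored as (x1, y1, x2, y2)
    x, y, w, h = cv2.boundingRect(contour)
    pred_boxes.append((x, y, x + w, y + h))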
@@ -130,6 +142,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
     for example in test_dataset:
         # Extract image and annotations
         image = example["image"]
+
+        original_shape = (len(image), len(image[0]))
+        image = preprocess(image)
+
         annotation = example.get("annotations", "").strip()
 
 
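One caveat with the new loop code: example["image"] is typically a PIL image in a Hugging Face dataset, and len() on a PIL image raises a TypeError, so original_shape only works if the image is already an array (where it yields (rows, cols)). cv2.resize, by contrast, expects its size argument as (width, height). A hedged sketch of capturing the size in the order cv2.resize expects (original_size is an illustrative name, not from the commit):

if isinstance(image, Image.Image):
    original_size = image.size                        # PIL already reports (width, height)
else:
    original_size = (image.shape[1], image.shape[0])  # numpy is (rows, cols); swap to (width, height)
image = preprocess(image)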
@@ -154,8 +170,10 @@ async def evaluate_image(request: ImageEvaluationRequest):
 
         probabilities = torch.sigmoid(logits)
         predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
-        predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
-
+        # predicted_mask_resized = cv2.resize(predicted_mask, (512, 512), interpolation=cv2.INTER_NEAREST)
+        predicted_mask_resized = cv2.resize(predicted_mask, original_shape, interpolation=cv2.INTER_NEAREST)
+
+
         # Extract predicted bounding boxes
         predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
         pred_boxes.append(predicted_boxes)
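Downstream, the resized mask feeds the box extraction unchanged. For reference, the same step written against the (width, height) tuple from the sketch above (an illustration under those assumptions, not the committed code):

probabilities = torch.sigmoid(logits)
predicted_mask = (probabilities[0, 1] > 0.30).cpu().numpy().astype(np.uint8)
# cv2.resize takes dsize as (width, height); nearest-neighbour keeps the mask binary
predicted_mask_resized = cv2.resize(predicted_mask, original_size, interpolation=cv2.INTER_NEAREST)
predicted_boxes = get_bounding_boxes_from_mask(predicted_mask_resized)
pred_boxes.append(predicted_boxes)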