File size: 3,562 Bytes
8f08648
 
 
 
 
 
 
8986db1
8f08648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
from PIL import Image, ImageDraw, ImageFont, ImageOps
import requests
from io import BytesIO
import os
from env import UPLOADED_IMAGES_DIR

# Load the Open Images v4 SSD MobileNet V2 detector from TF Hub.
# This runs once at import time; the model's 'default' signature is kept in a
# module-level global so that run_object_detection() reuses it on every call.
detector = hub.load("https://tfhub.dev/google/openimages_v4/ssd/mobilenet_v2/1").signatures['default']

# Class names (matched against the decoded "detection_class_entities" values)
# that the filtering and drawing helpers below keep; everything else is ignored.
TARGET_CLASSES = set(["Food processor", "Fast food", "Food", "Seafood", "Snack"])


def load_image_from_url(url, size=(640, 480), timeout=10):
    """Download an image and return it as an RGB PIL image of exactly ``size``.

    Args:
        url: Address of the image to fetch.
        size: Target ``(width, height)`` in pixels. ``ImageOps.fit`` scales
            and crops the image so the result has exactly this size.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled
            connection.

    Returns:
        A ``PIL.Image.Image`` in RGB mode with dimensions ``size``.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx HTTP status.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors here rather than handing an error page to PIL,
    # which would otherwise fail with an unrelated "cannot identify image".
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return ImageOps.fit(img, size, Image.Resampling.LANCZOS)


def run_object_detection(image: Image.Image):
    """Run the module-level detector on a single PIL image.

    The image is converted to a numpy array, given a leading batch axis,
    cast to float32 and divided by 255 before being passed to ``detector``.

    Returns:
        A ``(results, image_np)`` tuple: ``results`` maps each detector
        output name to its numpy value, and ``image_np`` is the unscaled
        array built from ``image``.
    """
    pixels = np.array(image)
    # Add the batch dimension first, then cast and rescale.
    batch = tf.convert_to_tensor(pixels)[tf.newaxis, ...]
    scaled = tf.cast(batch, tf.float32) / 255.0
    raw_outputs = detector(scaled)
    detections = {}
    for name, tensor in raw_outputs.items():
        detections[name] = tensor.numpy()
    return detections, pixels

def get_filtered_class_boxes(results):
    """Keep only the best-scoring detection for each target class.

    ``results`` is read through its "detection_scores",
    "detection_class_entities" (byte strings, decoded as UTF-8) and
    "detection_boxes" entries, which are indexed in parallel. Detections
    whose class is not in ``TARGET_CLASSES`` are dropped.

    Returns:
        Three parallel lists ``(boxes, classes, scores)``, ordered by the
        first appearance of each class. On equal scores the earlier
        detection is kept.
    """
    # Dicts preserve insertion order, and overwriting a key keeps its slot,
    # so this reproduces "first seen" ordering with one entry per class.
    best_by_class = {}
    for idx, score in enumerate(results["detection_scores"]):
        label = results["detection_class_entities"][idx].decode("utf-8")
        candidate_box = results["detection_boxes"][idx]
        if label not in TARGET_CLASSES:
            continue
        current = best_by_class.get(label)
        if current is None or score > current[1]:
            best_by_class[label] = (candidate_box, score)

    boxes = [box for box, _ in best_by_class.values()]
    classes = list(best_by_class)
    scores = [best for _, best in best_by_class.values()]
    return boxes, classes, scores


def crop_and_save(image_np, boxes, class_names, scores, min_score=0.3):
    """Crop each sufficiently confident detection and save it as a JPEG.

    Args:
        image_np: Image array indexed as ``[row, column, ...]``.
        boxes: Per-detection ``(ymin, xmin, ymax, xmax)`` values, multiplied
            by the image height/width to obtain pixel coordinates.
        class_names: Class name for each detection; used in the file name.
        scores: Confidence for each detection.
        min_score: Detections must score strictly above this to be kept.

    Returns:
        A list of ``(cropped_array, class_name, score)`` tuples, one per
        saved crop. Each crop is also written to ``UPLOADED_IMAGES_DIR`` as
        ``"<class_name>_<score:.2f>.jpg"``.
    """
    # Image dimensions do not change between detections; read them once.
    im_height, im_width = image_np.shape[0], image_np.shape[1]
    cropped_images = []
    for box, class_name, score in zip(boxes, class_names, scores):
        if not score > min_score:
            continue
        ymin, xmin, ymax, xmax = box
        # Clamp to the image: a slightly negative coordinate would otherwise
        # become a negative slice index and wrap around to the far edge.
        left = max(0, int(xmin * im_width))
        right = min(im_width, int(xmax * im_width))
        top = max(0, int(ymin * im_height))
        bottom = min(im_height, int(ymax * im_height))
        # A zero-area or inverted box yields an empty array that cannot be
        # encoded as an image; skip it instead of failing mid-batch.
        if right <= left or bottom <= top:
            continue
        cropped_image = image_np[top:bottom, left:right]
        cropped_images.append((cropped_image, class_name, score))
        # Create the output directory on demand so the first save cannot
        # fail just because the folder is missing.
        os.makedirs(UPLOADED_IMAGES_DIR, exist_ok=True)
        pil_image = Image.fromarray(cropped_image)
        pil_image.save(os.path.join(UPLOADED_IMAGES_DIR, f"{class_name}_{score:.2f}.jpg"))
    return cropped_images

def draw_boxes(image_np, boxes, class_names, scores, min_score=0.3):
    """Draw labelled rectangles for confident target-class detections.

    A detection is drawn only when its class name is in ``TARGET_CLASSES``
    and its score is strictly greater than ``min_score``. Each box is read
    as ``(ymin, xmin, ymax, xmax)`` and multiplied by the image size to get
    pixel coordinates.

    Returns:
        A new PIL image built from ``image_np`` with a red outline and a
        white "<label>: <percent>%" caption per drawn detection.
    """
    image_pil = Image.fromarray(image_np)
    canvas = ImageDraw.Draw(image_pil)
    label_font = ImageFont.load_default()
    im_width, im_height = image_pil.size

    for idx, score in enumerate(scores):
        label = class_names[idx]
        keep = label in TARGET_CLASSES and score > min_score
        if not keep:
            continue
        ymin, xmin, ymax, xmax = boxes[idx]
        left, right = xmin * im_width, xmax * im_width
        top, bottom = ymin * im_height, ymax * im_height
        canvas.rectangle([left, top, right, bottom], outline="red", width=2)
        caption = f"{label}: {score*100:.1f}%"
        canvas.text((left, top), caption, fill="white", font=label_font)
    return image_pil