File size: 3,846 Bytes
d6fd25b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from flask import Flask, request, jsonify
from flask_restx import Api, Resource, fields
from PIL import Image
import numpy as np
import io
import base64
import pandas as pd
import torch
from transformers import DPTImageProcessor, DPTForDepthEstimation, SegformerImageProcessor, SegformerForSemanticSegmentation
import json

# Class-index -> class-name mapping for the segmentation model. JSON object
# keys are always strings, which is why lookups use str(class_id).
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale encoding.
with open('id2label.json', 'r', encoding='utf-8') as json_file:
    id2label = json.load(json_file)


app = Flask(__name__)
api = Api(
    app,
    version='1.0',
    title='Food Segmentation API',
    description='API for food segmentation and weight estimation',
)

# Request schema for /process_image: a single required base64 image string.
_image_fields = {
    'image': fields.String(required=True, description='Base64 encoded image'),
}
image_payload = api.model('ImagePayload', _image_fields)

# Hugging Face model identifiers for the two pretrained networks.
DEPTH_CHECKPOINT = "Intel/dpt-hybrid-midas"
SEGMENTATION_CHECKPOINT = "prem-timsina/segformer-b0-finetuned-food"

# Depth estimation (DPT): preprocessor + model.
depth_processor = DPTImageProcessor.from_pretrained(DEPTH_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(
    DEPTH_CHECKPOINT, low_cpu_mem_usage=True
)

# Semantic segmentation (SegFormer): preprocessor + model.
seg_processor = SegformerImageProcessor.from_pretrained(SEGMENTATION_CHECKPOINT)
seg_model = SegformerForSemanticSegmentation.from_pretrained(SEGMENTATION_CHECKPOINT)

# Food density table: columns 1 and 2 of the first sheet (looked up by
# get_density as 'Food' and 'Density in g/ml'). Rows with any missing cell
# are discarded.
density_db_path = 'density_table_foodseg103.xlsx'
density_db = pd.read_excel(density_db_path, sheet_name=0, usecols=[1, 2]).dropna()



def get_density(food_type, table=None, default=0):
    """Look up the density (g/ml) of a food class in the density table.

    The first row whose ``Food`` column contains ``food_type`` as a
    case-insensitive *literal* substring is used.

    Args:
        food_type: Class name to search for (e.g. a value of ``id2label``).
        table: DataFrame with ``Food`` and ``Density in g/ml`` columns.
            Defaults to the module-level ``density_db``.
        default: Value returned when nothing matches (or ``food_type`` is
            empty).

    Returns:
        The matched density as a Python ``float``, otherwise ``default``.
    """
    if table is None:
        table = density_db
    # An empty pattern would "match" every row and return an arbitrary density.
    if not food_type:
        return default
    # regex=False: class names are plain text; as regex patterns, names with
    # metacharacters such as "(" or "+" silently fail to match or raise.
    hits = table[
        table['Food'].str.contains(food_type, case=False, regex=False, na=False)
    ]
    if hits.empty:
        return default  # Default density if not found
    # Plain float so downstream arithmetic/JSON never carries a NumPy scalar.
    return float(hits.iloc[0]['Density in g/ml'])

def estimate_depth(image):
    """Predict a depth map for ``image`` at the image's own resolution.

    The DPT model's raw prediction is resized with bicubic interpolation to
    (height, width) of the input before being returned.

    Args:
        image: RGB ``PIL.Image`` (``image.size`` is read as (width, height)).

    Returns:
        NumPy array of predicted depth values, squeezed, on the CPU.
    """
    batch = depth_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        raw_depth = depth_model(**batch).predicted_depth
    # PIL reports (width, height); interpolate wants (height, width).
    height_width = tuple(reversed(image.size))
    upsampled = torch.nn.functional.interpolate(
        raw_depth.unsqueeze(1),
        size=height_width,
        mode="bicubic",
        align_corners=False,
    )
    return upsampled.squeeze().cpu().numpy()

def segment_image(image):
    """Assign each cell of the segmentation model's output grid its top class.

    Args:
        image: RGB ``PIL.Image``.

    Returns:
        NumPy array of class indices (argmax over the class axis of the
        logits), squeezed, on the CPU. The logits are not resized here, so the
        grid follows the model's output resolution rather than ``image.size``
        (NOTE(review): presumably smaller than the input; confirm).
    """
    encoded = seg_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        class_scores = seg_model(**encoded).logits
    labels = class_scores.argmax(dim=1)
    return labels.squeeze().cpu().numpy()

def calculate_volume(depth_map, mask, scale=0.0001):
    """Estimate the volume under ``mask`` from a depth map.

    Multiplies the depth map by the mask, sums the result, scales it by
    ``scale`` (an uncalibrated per-pixel factor, 0.0001 by default), and
    rounds to 2 decimals.

    If the arrays differ in shape, the depth map is resampled onto the mask's
    grid with area interpolation. ``np.resize`` must not be used for this: it
    only repeats/truncates the flattened buffer, so depth values would be
    paired with the wrong pixels.

    Args:
        depth_map: 2-D array-like of depth values.
        mask: 2-D array-like, 1 where the class is present and 0 elsewhere.
        scale: Multiplier applied to the masked depth sum.

    Returns:
        Python ``float`` rounded to 2 decimals. Python float rather than a
        NumPy scalar, because np.float32 is not JSON serializable.
    """
    depth_map = np.asarray(depth_map)
    mask = np.asarray(mask)
    if depth_map.shape != mask.shape:
        # interpolate expects (N, C, H, W); add and then drop the two leading axes.
        depth_tensor = torch.from_numpy(
            np.ascontiguousarray(depth_map, dtype=np.float32)
        )[None, None]
        depth_map = torch.nn.functional.interpolate(
            depth_tensor, size=tuple(mask.shape), mode="area"
        )[0, 0].numpy()
    masked_sum = float(np.sum(depth_map * mask))
    return round(masked_sum * scale, 2)

@api.route('/process_image')
class ProcessImage(Resource):
    """Segment a food photo and estimate the weight of each detected class."""

    @api.expect(image_payload)
    def post(self):
        """Return per-class weight estimates for a posted base64 image.

        Expects a JSON body ``{"image": "<base64>"}``. The response is a JSON
        list of ``{"class_id", "class_name", "estimated_weight_g"}`` objects,
        one per class present in the segmentation mask. Malformed input
        (non-JSON body, missing/non-string ``image``, bad base64, undecodable
        image bytes) yields a 400 response instead of an unhandled 500.
        """
        data = request.get_json(silent=True)
        base64_image = data.get('image') if isinstance(data, dict) else None
        if not isinstance(base64_image, str) or not base64_image:
            return {'message': "JSON body must contain a non-empty base64 string field 'image'"}, 400

        # Decode the image. binascii.Error (bad base64) subclasses ValueError,
        # and PIL's UnidentifiedImageError (unreadable bytes) subclasses OSError.
        try:
            image_data = base64.b64decode(base64_image)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except (ValueError, OSError):
            return {'message': "'image' is not a valid base64-encoded image"}, 400

        # Estimate depth and segment image
        depth_map = estimate_depth(image)
        segmentation_mask = segment_image(image)

        # Resample the depth map onto the mask's grid. np.resize would only
        # repeat/truncate the flattened buffer and misalign depth with pixels.
        if depth_map.shape != segmentation_mask.shape:
            depth_tensor = torch.from_numpy(
                np.ascontiguousarray(depth_map, dtype=np.float32)
            )[None, None]
            depth_map = torch.nn.functional.interpolate(
                depth_tensor, size=tuple(segmentation_mask.shape), mode="area"
            )[0, 0].numpy()

        # Calculate volume and weight for each class
        results = []
        for class_id in np.unique(segmentation_mask):
            mask = (segmentation_mask == class_id).astype(np.uint8)
            volume = calculate_volume(depth_map, mask)
            class_name = id2label[str(class_id)]
            density = get_density(class_name)
            # Cast to Python floats: with the default density 0, the product
            # can be an np.float32, which jsonify cannot serialize.
            weight = round(float(volume) * float(density), 2)
            results.append({
                'class_id': int(class_id),
                'class_name': class_name,
                'estimated_weight_g': weight,
            })

        return jsonify(results)

#if __name__ == "__main__":
#    app.run(host='0.0.0.0', port=5000)