# Page header captured together with this file (not part of the program):
#   kubrabuzlu's picture
#   Add application file
#   d6fd25b
from flask import Flask, request, jsonify
from flask_restx import Api, Resource, fields
from PIL import Image
import numpy as np
import io
import base64
import pandas as pd
import torch
from transformers import DPTImageProcessor, DPTForDepthEstimation, SegformerImageProcessor, SegformerForSemanticSegmentation
import json
# Mapping from class id to label name, read once at import time.
# JSON object keys are always strings, which is why lookups in this file use
# str(class_id). The encoding is given explicitly so the file is not decoded
# with the platform's locale default.
with open('id2label.json', 'r', encoding='utf-8') as json_file:
    id2label = json.load(json_file)
# Flask application plus the Flask-RESTX wrapper that exposes the endpoints.
app = Flask(__name__)
api = Api(
    app,
    version='1.0',
    title='Food Segmentation API',
    description='API for food segmentation and weight estimation',
)

# Request body schema: a single required string field holding the image.
image_payload = api.model(
    'ImagePayload',
    {
        'image': fields.String(required=True, description='Base64 encoded image'),
    },
)
# Pretrained checkpoints, loaded once at import time and shared by all requests.
DEPTH_CHECKPOINT = "Intel/dpt-hybrid-midas"
SEGMENTATION_CHECKPOINT = "prem-timsina/segformer-b0-finetuned-food"

# Depth estimation: image processor + DPT model.
depth_processor = DPTImageProcessor.from_pretrained(DEPTH_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(
    DEPTH_CHECKPOINT,
    low_cpu_mem_usage=True,
)

# Semantic segmentation: image processor + SegFormer model.
seg_processor = SegformerImageProcessor.from_pretrained(SEGMENTATION_CHECKPOINT)
seg_model = SegformerForSemanticSegmentation.from_pretrained(SEGMENTATION_CHECKPOINT)
# Density lookup table: the second and third columns of the first worksheet,
# with incomplete rows (any missing cell) discarded.
density_db_path = 'density_table_foodseg103.xlsx'
density_db = pd.read_excel(
    density_db_path,
    sheet_name=0,
    usecols=[1, 2],
).dropna()
def get_density(food_type, table=None):
    """Look up the density of a food by name.

    The name is matched as a case-insensitive *literal* substring of the
    ``'Food'`` column. Literal matching matters: label names may contain
    characters such as ``(``, ``)`` or ``+`` that would otherwise be
    interpreted as regular-expression syntax and either raise or match the
    wrong row.

    Args:
        food_type: Food name to search for.
        table: Optional DataFrame with ``'Food'`` and ``'Density in g/ml'``
            columns. Defaults to the module-level ``density_db``.

    Returns:
        The ``'Density in g/ml'`` value of the first matching row, or ``0``
        when no row matches.
    """
    if table is None:
        table = density_db
    is_hit = table['Food'].str.contains(food_type, case=False, na=False, regex=False)
    hits = table[is_hit]
    if hits.empty:
        return 0  # Default density if not found
    return hits.iloc[0]['Density in g/ml']
def estimate_depth(image):
    """Predict a depth map for ``image`` and resize it to the image's size.

    Args:
        image: Image accepted by ``depth_processor``; its ``size`` attribute
            is read as ``(width, height)`` and reversed to get the target
            ``(height, width)``.

    Returns:
        NumPy array with the model's predicted depth, bicubically
        interpolated to the image's height and width, with singleton
        dimensions squeezed out.
    """
    target_hw = image.size[::-1]
    model_inputs = depth_processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        raw_depth = depth_model(**model_inputs).predicted_depth
        # Add a channel axis so interpolate() sees (batch, channel, H, W).
        resized = torch.nn.functional.interpolate(
            raw_depth.unsqueeze(1),
            size=target_hw,
            mode="bicubic",
            align_corners=False,
        )
    return resized.squeeze().cpu().numpy()
def segment_image(image):
    """Predict a class-id map for ``image`` with the segmentation model.

    Args:
        image: Image accepted by ``seg_processor``.

    Returns:
        NumPy array holding, for each spatial position of the model's logits,
        the index of the highest-scoring class. NOTE(review): the array has
        the logits' spatial size, which may differ from the input image size;
        callers in this file handle a shape mismatch - confirm whether the
        mask should be upsampled here instead.
    """
    batch = seg_processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        class_scores = seg_model(**batch).logits
    # Axis 1 is the class axis; the arg-max picks the winning class id.
    label_map = class_scores.argmax(dim=1)
    return label_map.squeeze().cpu().numpy()
def calculate_volume(depth_map, mask):
    """Sum the depth values under ``mask`` and scale them to a volume figure.

    When the two arrays differ in shape, the depth map is resampled to the
    mask's shape by nearest-neighbour lookup so that every mask pixel is
    paired with the depth value at the corresponding relative position.
    (``np.resize`` is not a spatial resize: it truncates or repeats the
    flattened buffer, which misaligns depth and mask.)

    Args:
        depth_map: Array of depth values.
        mask: Array of 0/1 (or boolean) values selecting the region.

    Returns:
        A Python ``float``: the masked depth sum multiplied by ``0.0001``
        and rounded to two decimals. A built-in float is returned (not a
        NumPy scalar) so the value is JSON serialisable.
    """
    depth_map = np.asarray(depth_map)
    mask = np.asarray(mask)
    if depth_map.shape != mask.shape:
        can_resample = (
            depth_map.ndim == 2 and mask.ndim == 2
            and depth_map.size > 0 and mask.size > 0
        )
        if can_resample:
            src_h, src_w = depth_map.shape
            dst_h, dst_w = mask.shape
            # For each target pixel, the source row/column it maps onto.
            row_idx = (np.arange(dst_h) * src_h) // dst_h
            col_idx = (np.arange(dst_w) * src_w) // dst_w
            depth_map = depth_map[row_idx[:, None], col_idx]
        else:
            # Non-2-D or empty input: keep the previous fallback behaviour.
            depth_map = np.resize(depth_map, mask.shape)
    masked_depth = depth_map * mask
    return round(float(np.sum(masked_depth)) * 0.0001, 2)
def _fit_depth_to_mask(depth_map, target_shape):
    """Resample ``depth_map`` to ``target_shape`` by nearest-neighbour lookup."""
    target_shape = tuple(target_shape)
    if depth_map.shape == target_shape:
        return depth_map
    if depth_map.ndim != 2 or len(target_shape) != 2 or depth_map.size == 0:
        # Not a plain 2-D map: keep the previous fallback behaviour.
        return np.resize(depth_map, target_shape)
    src_h, src_w = depth_map.shape
    dst_h, dst_w = target_shape
    # For each target pixel, the source row/column it maps onto.
    row_idx = (np.arange(dst_h) * src_h) // dst_h
    col_idx = (np.arange(dst_w) * src_w) // dst_w
    return depth_map[row_idx[:, None], col_idx]


@api.route('/process_image')
class ProcessImage(Resource):
    """Endpoint that estimates a weight for every class found in an image."""

    @api.expect(image_payload)
    def post(self):
        """Segment the posted image and estimate a weight per detected class.

        Expects a JSON body ``{"image": "<base64 string>"}``.

        Returns:
            JSON list with one object per class id present in the
            segmentation mask, holding ``class_id``, ``class_name`` and
            ``estimated_weight_g``.

        Aborts with HTTP 400 when the body is not a JSON object with a
        string ``image`` field, or when the image cannot be decoded.
        """
        data = request.json
        if not isinstance(data, dict) or not isinstance(data.get('image'), str):
            api.abort(400, "Request body must be a JSON object with a base64 'image' string")

        # Decode the image
        try:
            image_data = base64.b64decode(data['image'])
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        except (ValueError, OSError):
            # ValueError: malformed base64; OSError: unreadable image data.
            api.abort(400, "The 'image' field is not a valid base64-encoded image")

        # Estimate depth and segment image
        depth_map = estimate_depth(image)
        segmentation_mask = segment_image(image)

        # Bring the depth map onto the mask's grid with a spatial resample;
        # np.resize would only truncate/repeat the flattened buffer.
        depth_map = _fit_depth_to_mask(depth_map, segmentation_mask.shape)

        # Calculate volume and weight for each class
        results = []
        for class_id in np.unique(segmentation_mask):
            mask = (segmentation_mask == class_id).astype(np.uint8)
            volume = calculate_volume(depth_map, mask)
            class_name = id2label[str(class_id)]
            density = get_density(class_name)
            # Cast to a built-in float: NumPy scalars such as float32 are
            # not JSON serialisable.
            weight = float(round(volume * density, 2))
            results.append({
                'class_id': int(class_id),
                'class_name': class_name,
                'estimated_weight_g': weight
            })
        return jsonify(results)
#if __name__ == "__main__":
# app.run(host='0.0.0.0', port=5000)