# rld/utils/image_processing.py
# Uploaded by zynt31 ("Upload 10 files", commit 6b36551).
import torch
from torchvision import transforms
from PIL import Image
import io
# Same transformations used during training
def get_transform():
    """Build the preprocessing pipeline applied to each input image.

    The steps must stay identical to the training-time preprocessing:
    resize to 224x224, convert the PIL image to a tensor, then normalize
    each of the three channels with a fixed per-channel mean and std.

    Returns:
        A ``transforms.Compose`` that runs the three steps in order.
    """
    # Per-channel normalization statistics. These are the values widely used
    # with ImageNet-pretrained torchvision models; keep them in sync with
    # whatever the model was trained with.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)
def process_image(image_bytes: bytes) -> "torch.Tensor":
    """Decode encoded image bytes into a model-ready batch tensor.

    The bytes are decoded with PIL, converted to RGB, and passed through
    the pipeline returned by ``get_transform()``.

    Args:
        image_bytes: Raw encoded image data in any format PIL can open.

    Returns:
        The transformed image with a leading batch dimension of size 1
        (added via ``unsqueeze(0)``).

    Raises:
        PIL.UnidentifiedImageError: If PIL cannot identify the image data.
    """
    # Close the source image deterministically. convert() returns a new,
    # fully loaded image, so ``rgb`` stays valid after ``source`` is closed.
    with Image.open(io.BytesIO(image_bytes)) as source:
        rgb = source.convert("RGB")
    transform = get_transform()
    # Add a batch dimension so the result can be fed straight to a model.
    return transform(rgb).unsqueeze(0)