amaye15
committed on
Commit
·
8d41aec
1
Parent(s):
aea7238
Docstring added
Browse files- handler.py +35 -12
handler.py
CHANGED
|
@@ -66,50 +66,74 @@ from io import BytesIO
|
|
| 66 |
|
| 67 |
|
| 68 |
class EndpointHandler:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def __init__(self, path: str = "", default_batch_size: int = 4):
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
| 72 |
|
| 73 |
-
# Load the model and processor
|
| 74 |
self.model = ColQwen2.from_pretrained(
|
| 75 |
path,
|
| 76 |
torch_dtype=torch.bfloat16,
|
| 77 |
).eval()
|
| 78 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
| 79 |
|
| 80 |
-
# Determine the device
|
| 81 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 82 |
self.model.to(self.device)
|
| 83 |
-
|
| 84 |
-
# Set default batch size
|
| 85 |
self.default_batch_size = default_batch_size
|
| 86 |
|
| 87 |
def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
batch_images = self.processor.process_images(images)
|
| 90 |
batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
|
| 91 |
|
| 92 |
-
# Generate embeddings
|
| 93 |
with torch.no_grad():
|
| 94 |
image_embeddings = self.model(**batch_images)
|
| 95 |
|
| 96 |
-
# Convert embeddings to list format
|
| 97 |
return image_embeddings.cpu().tolist()
|
| 98 |
|
| 99 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
images_data = data.get("inputs", [])
|
| 102 |
batch_size = data.get("batch_size", self.default_batch_size)
|
| 103 |
|
| 104 |
if not images_data:
|
| 105 |
return {"error": "No images provided in 'inputs'."}
|
| 106 |
|
| 107 |
-
# Decode and validate images
|
| 108 |
images = []
|
| 109 |
for img_data in images_data:
|
| 110 |
if isinstance(img_data, str):
|
| 111 |
try:
|
| 112 |
-
# Assume base64-encoded image
|
| 113 |
image_bytes = base64.b64decode(img_data)
|
| 114 |
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 115 |
images.append(image)
|
|
@@ -118,7 +142,6 @@ class EndpointHandler:
|
|
| 118 |
else:
|
| 119 |
return {"error": "Images should be base64-encoded strings."}
|
| 120 |
|
| 121 |
-
# Process in batches with the specified or default batch size
|
| 122 |
embeddings = []
|
| 123 |
for i in range(0, len(images), batch_size):
|
| 124 |
batch_images = images[i : i + batch_size]
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
class EndpointHandler:
|
| 69 |
+
"""
|
| 70 |
+
A handler class for processing image data, generating embeddings using a specified model and processor.
|
| 71 |
+
|
| 72 |
+
Attributes:
|
| 73 |
+
model: The pre-trained model used for generating embeddings.
|
| 74 |
+
processor: The pre-trained processor used to process images before model inference.
|
| 75 |
+
device: The device (CPU or CUDA) used to run model inference.
|
| 76 |
+
default_batch_size: The default batch size for processing images in batches.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
def __init__(self, path: str = "", default_batch_size: int = 4):
|
| 80 |
+
"""
|
| 81 |
+
Initializes the EndpointHandler with a specified model path and default batch size.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
path (str): Path to the pre-trained model and processor.
|
| 85 |
+
default_batch_size (int): Default batch size for image processing.
|
| 86 |
+
"""
|
| 87 |
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
| 88 |
|
|
|
|
| 89 |
self.model = ColQwen2.from_pretrained(
|
| 90 |
path,
|
| 91 |
torch_dtype=torch.bfloat16,
|
| 92 |
).eval()
|
| 93 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
| 94 |
|
|
|
|
| 95 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 96 |
self.model.to(self.device)
|
|
|
|
|
|
|
| 97 |
self.default_batch_size = default_batch_size
|
| 98 |
|
| 99 |
def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
|
| 100 |
+
"""
|
| 101 |
+
Processes a batch of images and generates embeddings.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
images (List[Image.Image]): List of images to process.
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
List[List[float]]: List of embeddings for each image.
|
| 108 |
+
"""
|
| 109 |
batch_images = self.processor.process_images(images)
|
| 110 |
batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
|
| 111 |
|
|
|
|
| 112 |
with torch.no_grad():
|
| 113 |
image_embeddings = self.model(**batch_images)
|
| 114 |
|
|
|
|
| 115 |
return image_embeddings.cpu().tolist()
|
| 116 |
|
| 117 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 118 |
+
"""
|
| 119 |
+
Processes input data containing base64-encoded images, decodes them, and generates embeddings.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
data (Dict[str, Any]): Dictionary containing input images and optional batch size.
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
Dict[str, Any]: Dictionary containing generated embeddings or error messages.
|
| 126 |
+
"""
|
| 127 |
images_data = data.get("inputs", [])
|
| 128 |
batch_size = data.get("batch_size", self.default_batch_size)
|
| 129 |
|
| 130 |
if not images_data:
|
| 131 |
return {"error": "No images provided in 'inputs'."}
|
| 132 |
|
|
|
|
| 133 |
images = []
|
| 134 |
for img_data in images_data:
|
| 135 |
if isinstance(img_data, str):
|
| 136 |
try:
|
|
|
|
| 137 |
image_bytes = base64.b64decode(img_data)
|
| 138 |
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 139 |
images.append(image)
|
|
|
|
| 142 |
else:
|
| 143 |
return {"error": "Images should be base64-encoded strings."}
|
| 144 |
|
|
|
|
| 145 |
embeddings = []
|
| 146 |
for i in range(0, len(images), batch_size):
|
| 147 |
batch_images = images[i : i + batch_size]
|