Spaces:
Sleeping
Sleeping
Mohamed Abdel-Moneim
fix: enhance squeeze_mask method to handle various mask shapes and add logging in image_resegment_job
71db628 | import asyncio | |
| from enum import verify | |
| import io | |
| import random | |
| from PIL import Image | |
| from typing import Any, List, Optional | |
| from app.services.sam_service import SAM_service | |
| import httpx | |
| import numpy as np | |
| http_client = httpx.AsyncClient(timeout=30.0) | |
| async def download_image(image_url: str, job_id: str) -> Optional[Image.Image]: | |
| print(f"--- [Segment {job_id}] Downloading image from: {image_url} ---") | |
| try: | |
| response = await http_client.get(image_url) | |
| response.raise_for_status() | |
| image_bytes = response.content | |
| image = Image.open(io.BytesIO(image_bytes)) | |
| print(f"--- [Segment {job_id}] Downloaded {len(image_bytes)} bytes. ---") | |
| return image | |
| except httpx.HTTPStatusError as e: | |
| print(f"--- [Segment {job_id}] ERROR: HTTP error while downloading: {e} ---") | |
| except httpx.RequestError as e: | |
| print(f"--- [Segment {job_id}] ERROR: Network error while downloading: {e} ---") | |
| return None | |
| async def image_segment_job( | |
| job_id: str, | |
| sam_instance: SAM_service, | |
| image_url: str, | |
| TRUSTED_HOST: str, | |
| callback_url: Optional[str] = None | |
| ): | |
| """ | |
| This is background worker function. | |
| It runs *after* the API response has been sent. | |
| It will download the image from the URL and process it. | |
| this will be changed to Redis job worker in the future | |
| """ | |
| print(f"--- [Segment {job_id}] STARTING ---") | |
| try: | |
| url_host = httpx.URL(image_url).host | |
| if url_host != TRUSTED_HOST: | |
| print(f"--- [Segment {job_id}] ERROR: Untrusted URL host: {url_host} ---") | |
| return | |
| except Exception as e: | |
| print(f"--- [Segment {job_id}] ERROR: Invalid URL: {e} ---") | |
| return | |
| print(f"--- [Segment {job_id}] Downloading image from: {image_url} ---") | |
| try: | |
| response = await http_client.get(image_url) | |
| response.raise_for_status() | |
| image_bytes = response.content | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| print(f"--- [Segment {job_id}] Downloaded {len(image_bytes)} bytes, Mode={image.mode}, Size={image.size} ---") | |
| except httpx.HTTPStatusError as e: | |
| print(f"--- [Segment {job_id}] ERROR: HTTP error while downloading: {e} ---") | |
| return | |
| except httpx.RequestError as e: | |
| print(f"--- [Segment {job_id}] ERROR: Network error while downloading: {e} ---") | |
| return | |
| # Here, we would pass 'image_bytes' to ML model, | |
| # image processing library, etc. | |
| print(f"--- [Segment {job_id}] Processing image... ---") | |
| masks , boxes, _ = sam_instance.segment_with_prompt(image, "clothes. shoe. shirt. glove. pants. boot.".lower()) | |
| print(f"--- [Segment Job {job_id}] COMPLETED ---") | |
| if callback_url: | |
| print(f"--- [Segment {job_id}] Notifying callback URL: {callback_url} ---") | |
| # Simulate some results | |
| if len(masks.shape) == 4: | |
| masks = np.squeeze(masks, axis=1) | |
| masks = sam_instance.squeeze_mask(masks).astype(np.uint8) | |
| results = { | |
| "job_id": job_id, | |
| "status": "segmented", | |
| "masks": [masks[i].tolist() for i in range(len(masks))], | |
| "boxes": boxes.tolist() | |
| } | |
| try: | |
| response = await http_client.put(callback_url, json=results) | |
| response.raise_for_status() | |
| print(f"--- [Segment {job_id}] Callback notification successful (Status: {response.status_code}) ---") | |
| except httpx.HTTPStatusError as e: | |
| # non-2xx HTTP response | |
| print(f"--- [Segment {job_id}] HTTPStatusError: {e.response.status_code}, Body: {e.response.text} ---") | |
| except httpx.ReadError as e: | |
| # network read failure (like your case) | |
| print(f"--- [Segment {job_id}] ReadError: {str(e)} ---") | |
| except httpx.RequestError as e: | |
| # any other network issue | |
| print(f"--- [Segment {job_id}] RequestError: {type(e).__name__}: {str(e)} ---") | |
| async def image_resegment_job( | |
| job_id: str, | |
| sam_instance: SAM_service, | |
| image_url: str, | |
| TRUSTED_HOST: str, | |
| pos_points: List[List[int]], | |
| neg_points: List[List[int]], | |
| boxes: Optional[List[List[int]]] = None, | |
| callback_url: Optional[str] = None | |
| ): | |
| """ | |
| This is background worker function for re-segment job. | |
| It runs *after* the API response has been sent. | |
| It will download the image from the URL and process it using the points. | |
| this will be changed to Redis job worker in the future. | |
| """ | |
| print(f"--- [Re-Segment Job {job_id}] STARTING ---") | |
| print(f"--- [Re-Segment Job {job_id}] Received {len(pos_points) + len(neg_points)} points ---") | |
| try: | |
| url_host = httpx.URL(image_url).host | |
| if url_host != TRUSTED_HOST: | |
| print(f"--- [Re-Segment Job {job_id}] ERROR: Untrusted URL host: {url_host} ---") | |
| return | |
| except Exception as e: | |
| print(f"--- [Re-Segment Job {job_id}] ERROR: Invalid URL: {e} ---") | |
| return | |
| # 2. Download the image | |
| print(f"--- [Re-Segment Job {job_id}] Downloading image from: {image_url} ---") | |
| try: | |
| response = await http_client.get(image_url) | |
| response.raise_for_status() | |
| image_bytes = response.content | |
| image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| print(f"--- [Segment {job_id}] Downloaded {len(image_bytes)} bytes, Mode={image.mode}, Size={image.size} ---") | |
| except httpx.HTTPStatusError as e: | |
| print(f"--- [Re-Segment Job {job_id}] ERROR: HTTP error while downloading: {e} ---") | |
| return | |
| except httpx.RequestError as e: | |
| print(f"--- [Re-Segment Job {job_id}] ERROR: Network error while downloading: {e} ---") | |
| return | |
| print(f"--- [Re-Segment Job {job_id}] Processing image with points... ---") | |
| masks, _, _ = sam_instance.resegment(image, pos_points, neg_points, boxes=boxes) | |
| # 4. --- Job Finished --- | |
| print(f"--- [Re-Segment Job {job_id}] COMPLETED ---") | |
| if callback_url: | |
| print(f"--- [Re-Segment Job {job_id}] Notifying callback URL: {callback_url} ---") | |
| # Simulate some results | |
| if len(masks.shape) == 4: | |
| masks = np.squeeze(masks, axis=1) | |
| masks = sam_instance.squeeze_mask(masks).astype(np.uint8) | |
| print(f"--- [Re-Segment Job {job_id}] Generated {masks.shape} masks ---") | |
| results = { | |
| "job_id": job_id, | |
| "status": "re-segmented", | |
| "masks": [masks[i].tolist() for i in range(len(masks))] | |
| } | |
| try: | |
| response = await http_client.put(callback_url, json=results) | |
| response.raise_for_status() # Raise exception for 4xx/5xx errors | |
| print(f"--- [Re-Segment Job {job_id}] Callback notification successful (Status: {response.status_code}) ---") | |
| except httpx.RequestError as e: | |
| # Handle network errors, timeouts, etc. | |
| print(f"--- [Segment {job_id}] ERROR: {type(e).__name__}: {str(e)} ---") | |
| except httpx.HTTPStatusError as e: | |
| # Handle non-2xx responses | |
| print(f"--- [Re-Segment Job {job_id}] ERROR: Callback notification received non-2xx status: {e.response.status_code} ---") | |