Spaces:
Sleeping
Sleeping
| import base64 | |
| from PIL import Image | |
| from io import BytesIO | |
| from fastapi import HTTPException | |
def base64_to_image(base64_str: str) -> Image.Image:
    """Convert a base64 string (optionally a data URL) to an RGB PIL Image.

    Accepts raw base64 or a data URL such as
    ``data:image/jpeg;base64,<payload>``; everything up to and including the
    first comma is discarded before decoding. Pixel data is decoded eagerly,
    and any transparency (alpha modes RGBA/LA/PA, or a palette image carrying
    a ``transparency`` entry) is flattened onto a white background, so the
    returned image is always in mode ``RGB``.

    Args:
        base64_str: Base64 encoded image string, with or without a
            ``data:...;base64,`` prefix.

    Returns:
        PIL.Image: Decoded image in RGB mode.

    Raises:
        HTTPException: Status 400 if the string cannot be decoded into an
            image (bad base64, unknown format, truncated/corrupt data).
    """
    try:
        # Handle frontend base64 format (data:image/jpeg;base64,{base64_data})
        if "," in base64_str:
            base64_str = base64_str.split(",", 1)[1]
        image_data = base64.b64decode(base64_str)
        image = Image.open(BytesIO(image_data))
        # Image.open only parses the header. Decode the pixels now so that a
        # truncated/corrupt payload fails inside this try (-> 400) instead of
        # surfacing later in the caller as an unrelated error.
        image.load()

        # Palette images can carry transparency in image.info rather than in
        # an alpha band; treat them like the explicit alpha modes.
        has_alpha = image.mode in ("RGBA", "LA", "PA") or (
            image.mode == "P" and "transparency" in image.info
        )
        if has_alpha:
            # Normalise every transparent mode to RGBA so a single alpha
            # channel can serve as the paste mask over a white canvas.
            rgba = image.convert("RGBA")
            background = Image.new("RGB", rgba.size, (255, 255, 255))
            background.paste(rgba, mask=rgba.getchannel("A"))
            image = background
        elif image.mode != "RGB":
            image = image.convert("RGB")
        return image
    except Exception as e:
        print(f"Base64 decoding error: {str(e)}")
        raise HTTPException(status_code=400, detail=f"Invalid Base64 image: {str(e)}") from e
def image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
    """Encode a PIL Image as a base64 string (no ``data:`` prefix).

    JPEG has no alpha channel, so when encoding as JPEG:
    - transparent images (alpha modes RGBA/LA/PA, or a palette image carrying
      a ``transparency`` entry) are flattened onto a white background;
    - any other non-RGB mode is converted to RGB.

    Args:
        image: PIL Image object.
        format: Output format (JPEG, PNG, etc.), compared case-insensitively.

    Returns:
        str: Base64 encoding of the saved image bytes. If saving in a
        non-JPEG format fails, the image is re-encoded once as JPEG, so the
        returned data may be JPEG even when another format was requested.

    Raises:
        Exception: Whatever ``Image.save`` raises when JPEG encoding fails.
    """
    try:
        if format.upper() == "JPEG":
            # Palette images can carry transparency in image.info rather than
            # in an alpha band; treat them like the explicit alpha modes.
            has_alpha = image.mode in ("RGBA", "LA", "PA") or (
                image.mode == "P" and "transparency" in image.info
            )
            if has_alpha:
                # Normalise to RGBA so one alpha channel is the paste mask.
                rgba = image.convert("RGBA")
                background = Image.new("RGB", rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.getchannel("A"))
                image = background
            elif image.mode != "RGB":
                image = image.convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format=format)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"Error converting image to base64: {str(e)}")
        # Best-effort fallback: retry once as JPEG. NOTE: the caller then
        # receives JPEG bytes regardless of the format it asked for.
        if format.upper() != "JPEG":
            return image_to_base64(image, format="JPEG")
        raise
def is_image_file(filename: str) -> bool:
    """Report whether a file name ends with a recognised image extension.

    The check is case-insensitive and purely textual; the file is never
    opened or inspected.

    Args:
        filename: File name (or path) to inspect.

    Returns:
        bool: True when the lower-cased name ends with one of .png, .jpg,
        .jpeg, .bmp, .gif, .tiff or .webp.
    """
    lowered = filename.lower()
    return any(
        lowered.endswith(suffix)
        for suffix in (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    )
def get_image_format(filename: str) -> str:
    """Get the format to use for saving an image based on its filename.

    The extension is the text after the last ``.`` in ``filename``, compared
    case-insensitively. A name containing no ``.`` has no extension.

    Args:
        filename: Name of the file.

    Returns:
        str: Format to use ('JPEG', 'PNG', 'WEBP' or 'GIF'). Unrecognised or
        missing extensions default to 'JPEG'.
    """
    formats = {
        "jpg": "JPEG",
        "jpeg": "JPEG",
        "png": "PNG",
        "webp": "WEBP",
        "gif": "GIF",
    }
    # rpartition returns an empty separator when there is no '.', so a bare
    # name such as "png" is not mistaken for an extension.
    _, dot, ext = filename.lower().rpartition(".")
    if not dot:
        return "JPEG"  # Default to JPEG
    return formats.get(ext, "JPEG")  # Default to JPEG