|
|
import base64
import functools
import io
from typing import Any, Dict, Optional, Union

from huggingface_hub import InferenceClient
from PIL import Image
|
|
|
|
|
def encode_image(image: Image.Image) -> str:
    """Serialize a PIL image as PNG and return it as base64 text.

    Args:
        image: The image to serialize.

    Returns:
        The PNG-encoded bytes of ``image``, base64-encoded and decoded to a
        ``str`` with UTF-8.
    """
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
|
|
|
|
def decode_image(
    image_data: Union[str, bytes],
    timeout: Optional[float] = 30.0,
) -> Image.Image:
    """Open an image supplied as raw bytes, a base64 string, or an HTTP(S) URL.

    Args:
        image_data: One of:
            * ``bytes`` -- encoded image bytes, handed straight to PIL;
            * a ``str`` starting with ``http://`` or ``https://`` -- the URL is
              downloaded with ``requests`` and the response body is used;
            * a ``str`` data URI (``data:<mime>;base64,<payload>``) -- the
              header up to the first comma is dropped and the payload is
              base64-decoded;
            * any other ``str`` -- base64-decoded.
        timeout: Seconds to wait on the HTTP(S) download before giving up.
            Only used for URL input. ``None`` disables the timeout.

    Returns:
        The image opened by ``PIL.Image.open``.

    Raises:
        requests.HTTPError: If the URL download returns an error status.
        requests.Timeout: If the download exceeds ``timeout``.
        binascii.Error: If a base64 payload is malformed (e.g. bad padding).
        PIL.UnidentifiedImageError: If the bytes are not a recognizable image.
    """
    if isinstance(image_data, str):
        if image_data.startswith(("http://", "https://")):
            # Imported lazily so ``requests`` is only needed for URL input.
            import requests

            # Without a timeout, a stalled server would block the caller forever.
            response = requests.get(image_data, timeout=timeout)
            response.raise_for_status()
            image_data = response.content
        else:
            if image_data.startswith("data:"):
                # Drop a data-URI header such as "data:image/png;base64,".
                # Otherwise its characters would corrupt the decoded bytes.
                _, _, image_data = image_data.partition(",")
            image_data = base64.b64decode(image_data)

    return Image.open(io.BytesIO(image_data))
|
|
|
|
|
def handle_hf_error(func):
    """Decorate ``func`` so that exceptions are returned as an error message.

    On success the wrapper returns whatever ``func`` returns. If ``func`` raises
    any ``Exception`` subclass, the exception is swallowed and the string
    ``"Error executing task: <message>"`` is returned instead. Callers must
    therefore check for a ``str`` result. ``BaseException`` subclasses that
    are not ``Exception`` (e.g. ``KeyboardInterrupt``) still propagate.

    ``functools.wraps`` copies ``func``'s ``__name__``, ``__doc__`` and other
    metadata onto the wrapper so introspection shows the original function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            # Deliberate best-effort: report the failure as a value, not a raise.
            return f"Error executing task: {str(e)}"

    return wrapper
|
|
|
|
|
@handle_hf_error
def run_text_generation(
    client: InferenceClient,
    prompt: str,
    model: Optional[str] = None,
    **kwargs,
) -> str:
    """Request text generation for ``prompt`` through ``client``.

    Args:
        client: Inference client that performs the request.
        prompt: Prompt text passed positionally to ``client.text_generation``.
        model: Optional model identifier forwarded as ``model=``. ``None`` is
            forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``client.text_generation``.

    Returns:
        The value returned by ``client.text_generation``. Because of
        ``@handle_hf_error``, an exception during the call is instead returned
        as an ``"Error executing task: ..."`` string.
    """
    completion = client.text_generation(prompt, model=model, **kwargs)
    return completion
|
|
|
|
|
@handle_hf_error
def run_image_generation(
    client: InferenceClient,
    prompt: str,
    model: Optional[str] = None,
    **kwargs,
) -> Union[Image.Image, str]:
    """Request an image for ``prompt`` through ``client.text_to_image``.

    Args:
        client: Inference client that performs the request.
        prompt: Prompt text passed positionally to ``client.text_to_image``.
        model: Optional model identifier forwarded as ``model=``. ``None`` is
            forwarded unchanged.
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``client.text_to_image``.

    Returns:
        On success, the value returned by ``client.text_to_image`` (expected
        to be a ``PIL.Image.Image``). On failure, ``@handle_hf_error`` returns
        an ``"Error executing task: ..."`` string instead of raising. The
        annotation is therefore ``Union[Image.Image, str]``, and callers should
        check ``isinstance(result, str)`` before using image methods.
    """
    return client.text_to_image(prompt, model=model, **kwargs)
|
|
|
|
|
|
|
|
|