Spaces:
Sleeping
Sleeping
| import os | |
| import replicate | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import tempfile | |
| from typing import Optional | |
class ReplicateHandler:
    """Run an image-to-image model on Replicate and return the result as a PIL image.

    The Replicate model identifier and its default input settings are supplied
    by the caller at construction time.
    """

    # Seconds to wait for the result-image download. Without a timeout,
    # ``requests.get`` can block indefinitely on a stalled connection.
    DOWNLOAD_TIMEOUT_SECONDS = 120

    def __init__(self, model: str, default_settings: dict):
        """Store the model reference and default inputs, and check credentials.

        Args:
            model: Replicate model identifier passed to ``replicate.run``.
            default_settings: Extra model inputs sent with every request;
                per-call ``custom_settings`` in ``generate`` override them.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is not set in the environment.
        """
        self.model = model
        self.default_settings = default_settings
        # Fail fast at construction time. The token is only checked, not stored;
        # presumably the replicate client reads it from the environment itself.
        if not os.getenv("REPLICATE_API_TOKEN"):
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    def _image_to_base64_url(self, image: Image.Image) -> str:
        """Encode ``image`` as a PNG data URL (``data:image/png;base64,...``)."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"

    def _save_temp_image(self, image: Image.Image) -> str:
        """Write ``image`` to a new temporary PNG file and return its path.

        The caller owns the returned file and is responsible for deleting it.
        The OS-level handle is closed before returning, and the file is
        removed again if encoding fails.
        """
        fd, path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as handle:
                image.save(handle, format="PNG")
        except BaseException:
            # Don't leave a half-written temp file behind on failure.
            os.unlink(path)
            raise
        return path

    @staticmethod
    def _extract_result_url(output) -> str:
        """Return the final output URL from a ``replicate.run`` result.

        ``output`` may be an iterable of outputs (e.g. a streamed/iterated
        prediction), in which case the last item is taken as the final image,
        or a single output. A bare string is treated as one URL rather than
        being iterated character by character. Items exposing a ``url``
        attribute (assumed to be replicate file-output objects; confirm
        against the installed client version) are unwrapped to that URL.

        Raises:
            ValueError: If no non-empty output was produced.
        """
        if output is None:
            items = ()
        elif isinstance(output, str) or hasattr(output, "url"):
            items = (output,)
        else:
            items = output

        result = None
        for item in items:
            result = item  # later items supersede earlier ones
        if not result:
            raise ValueError("No output received from Replicate")
        return str(getattr(result, "url", result))

    def generate(
        self,
        input_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        custom_settings: Optional[dict] = None
    ) -> Image.Image:
        """
        Generate image using Replicate InstantID

        Args:
            input_image: PIL Image sent to the model as the ``image`` input
            prompt: Positive prompt
            negative_prompt: Negative prompt
            custom_settings: Override default settings (cfg, steps, etc.)

        Returns:
            Generated PIL Image, downloaded from the URL the model returned

        Raises:
            ValueError: If Replicate produced no output.
            requests.HTTPError: If downloading the result image fails.
        """
        # Per-call settings override the defaults; the defaults dict itself
        # is never mutated.
        settings = {**self.default_settings, **(custom_settings or {})}

        temp_path = self._save_temp_image(input_image)
        try:
            # Keep the file open only while the prediction is submitted and
            # its output consumed; closing it before unlinking is required on
            # Windows and avoids leaking a descriptor per call.
            with open(temp_path, "rb") as image_file:
                input_params = {
                    "image": image_file,
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    **settings
                }
                output = replicate.run(self.model, input=input_params)
                result_url = self._extract_result_url(output)

            # Download and convert to PIL
            import requests
            response = requests.get(result_url, timeout=self.DOWNLOAD_TIMEOUT_SECONDS)
            # Surface HTTP failures directly instead of as an image-decode error.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        finally:
            # Cleanup temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)