Spaces:
Running
Running
| import replicate | |
| from PIL import Image | |
| import io | |
| import requests | |
| import base64 | |
# Pixel dimensions used for each supported aspect ratio; unknown ratios fall
# back to 512x512 inside generate_image.
_ASPECT_RATIO_DIMENSIONS = {
    "1:1": (512, 512),
    "16:9": (912, 512),
    "3:2": (768, 512),
    "2:3": (512, 768),
    "4:5": (512, 640),
    "5:4": (640, 512),
}

# Negative prompt applied when the caller passes None or an empty string.
_DEFAULT_NEGATIVE_PROMPT = "ugly, blurry, low quality, distorted, deformed"

# Pinned SDXL model version on Replicate.
_SDXL_MODEL_REF = (
    "stability-ai/sdxl:"
    "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
)


def generate_image(
    prompt,
    num_steps=30,
    guidance_scale=7.5,
    aspect_ratio="1:1",
    replicate_api_key=None,
    lora_url=None,
    negative_prompt=None,
    download_timeout=60,
):
    """
    Generate an image using Stable Diffusion XL via the Replicate API.

    Args:
        prompt (str): The text prompt for image generation.
        num_steps (int): Number of inference steps.
        guidance_scale (float): Guidance scale for generation.
        aspect_ratio (str): Desired aspect ratio: "1:1", "16:9", "3:2", "2:3",
            "4:5" or "5:4". Any other value silently falls back to 512x512.
        replicate_api_key (str): API key for Replicate. Required; a missing or
            empty key returns an error message without calling the API.
        lora_url (str, optional): URL to LoRA weights, forwarded to the model
            as the "lora_urls" input.
        negative_prompt (str, optional): Negative prompt for generation. When
            None or empty, a generic quality-related default is used.
        download_timeout (float, optional): Seconds to wait when downloading
            the generated image. Defaults to 60.

    Returns:
        tuple: (image, message). On success, image is a fully decoded PIL
        Image and message is "Success". On failure, image is None and message
        describes the problem. Exceptions raised while generating or
        downloading are caught and reported through the message rather than
        propagated.
    """
    try:
        if not replicate_api_key:
            return None, "Please provide a Replicate API key"

        width, height = _ASPECT_RATIO_DIMENSIONS.get(aspect_ratio, (512, 512))

        model_params = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or _DEFAULT_NEGATIVE_PROMPT,
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height,
            "scheduler": "DPMSolverMultistep",  # other schedulers can be tried
            "num_outputs": 1,
        }
        if lora_url:
            # NOTE(review): confirm the pinned SDXL version accepts a
            # "lora_urls" input; an unknown input may be rejected or ignored.
            model_params["lora_urls"] = lora_url

        client = replicate.Client(api_token=replicate_api_key)
        output = client.run(_SDXL_MODEL_REF, input=model_params)

        if not output:
            return None, "No image generated"
        image_url = output[0]

        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(image_url, timeout=download_timeout)
        if response.status_code != 200:
            return None, f"Failed to download image: {response.status_code}"

        image = Image.open(io.BytesIO(response.content))
        # Image.open only parses the header. Decode the pixel data now so a
        # corrupt download is reported here as (None, message) instead of
        # failing later in the caller after "Success" was returned.
        image.load()
        return image, "Success"
    except Exception as e:
        return None, f"Error generating image: {e}"
def encode_image_to_base64(image):
    """Serialize a PIL image to a base64-encoded PNG string.

    Args:
        image: The image to encode. Only PIL ``Image.Image`` instances are
            encoded.

    Returns:
        str or None: The base64 text of the PNG-encoded image, or None when
        ``image`` is not a PIL Image.
    """
    if not isinstance(image, Image.Image):
        return None
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')