File size: 3,094 Bytes
cb9b654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import replicate
from PIL import Image
import io
import requests
import base64

# Output (width, height) in pixels for each supported aspect ratio; an unknown
# ratio falls back to 512x512 in generate_image.
_ASPECT_RATIO_DIMENSIONS = {
    "1:1": (512, 512),
    "16:9": (912, 512),
    "3:2": (768, 512),
    "2:3": (512, 768),
    "4:5": (512, 640),
    "5:4": (640, 512),
}

# Used when the caller passes no negative prompt (None or an empty string).
_DEFAULT_NEGATIVE_PROMPT = "ugly, blurry, low quality, distorted, deformed"

# Pinned SDXL model version on Replicate.
_SDXL_MODEL = (
    "stability-ai/sdxl:"
    "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
)


def _first_output_url(output):
    """Return the first image URL from a Replicate ``run`` result, or None if empty.

    Accepts a list/tuple, a lazy iterator, or a single value; the chosen item is
    passed through ``str()``. NOTE(review): newer replicate clients return
    file-output objects whose ``str()`` is the URL -- confirm against the
    installed client version.
    """
    if not output:
        return None
    # A bare string (or a single file-output object) is one URL, not a sequence
    # to index into; indexing a str would yield its first character.
    if isinstance(output, str) or hasattr(output, "url"):
        return str(output)
    # next(iter(...)) works for lists and for lazy iterators, which have no len().
    first = next(iter(output), None)
    return None if first is None else str(first)


def generate_image(
    prompt,
    num_steps=30,
    guidance_scale=7.5,
    aspect_ratio="1:1",
    replicate_api_key=None,
    lora_url=None,
    negative_prompt=None,
    download_timeout=60,
):
    """
    Generate an image using Stable Diffusion (SDXL) via the Replicate API.

    Args:
        prompt (str): The text prompt for image generation.
        num_steps (int): Number of inference steps.
        guidance_scale (float): Guidance scale for generation.
        aspect_ratio (str): Desired aspect ratio ("1:1", "16:9", "3:2", "2:3",
            "4:5", "5:4"); any other value falls back to 512x512.
        replicate_api_key (str): API key for Replicate. Required.
        lora_url (str, optional): URL to LoRA weights, sent as ``lora_urls``.
        negative_prompt (str, optional): Negative prompt; a built-in default
            is used when this is None or empty.
        download_timeout (float): Seconds to wait when downloading the
            generated image before giving up.

    Returns:
        tuple: ``(image, message)``. On success, ``image`` is a fully loaded
        PIL image and ``message`` is "Success". On any failure, ``image`` is
        None and ``message`` describes the problem. This function does not
        raise; every exception is converted into an error message.
    """
    if not replicate_api_key:
        return None, "Please provide a Replicate API key"

    try:
        width, height = _ASPECT_RATIO_DIMENSIONS.get(aspect_ratio, (512, 512))

        model_params = {
            "prompt": prompt,
            "negative_prompt": negative_prompt or _DEFAULT_NEGATIVE_PROMPT,
            "num_inference_steps": num_steps,
            "guidance_scale": guidance_scale,
            "width": width,
            "height": height,
            "scheduler": "DPMSolverMultistep",  # You can experiment with different schedulers
            "num_outputs": 1,
        }
        if lora_url:
            model_params["lora_urls"] = lora_url

        client = replicate.Client(api_token=replicate_api_key)
        output = client.run(_SDXL_MODEL, input=model_params)

        image_url = _first_output_url(output)
        if image_url is None:
            return None, "No image generated"

        # Without a timeout, requests waits forever on a stalled connection.
        response = requests.get(image_url, timeout=download_timeout)
        if response.status_code != 200:
            return None, f"Failed to download image: {response.status_code}"

        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is reported here
        # instead of failing later in the caller.
        image.load()
        return image, "Success"

    except Exception as e:
        # Boundary function: report every failure as a message, never raise.
        return None, f"Error generating image: {e}"

def encode_image_to_base64(image):
    """Return *image* as a base64-encoded PNG string.

    Anything that is not a PIL ``Image.Image`` yields None instead.
    """
    if not isinstance(image, Image.Image):
        return None

    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')