File size: 2,928 Bytes
d7c5b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import replicate
from PIL import Image
import io
import base64
import tempfile
from typing import Optional


class ReplicateHandler:
    """Run an image-to-image model hosted on Replicate.

    Attributes:
        model: Model identifier passed to ``replicate.run``.
        default_settings: Extra model inputs sent with every request unless
            overridden per call via ``custom_settings``.
    """

    # Seconds to wait while downloading the generated image. Without a
    # timeout a stalled connection would block the caller indefinitely.
    DOWNLOAD_TIMEOUT = 60

    def __init__(self, model: str, default_settings: dict):
        """Store the configuration and fail fast when no API token is set.

        Args:
            model: Model identifier passed to ``replicate.run``.
            default_settings: Default model inputs merged into every request.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is unset or empty.
        """
        self.model = model
        self.default_settings = default_settings

        # Verify the API token up front so misconfiguration surfaces at
        # construction time rather than on the first generate() call.
        if not os.getenv("REPLICATE_API_TOKEN"):
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    def _image_to_base64_url(self, image: Image.Image) -> str:
        """Encode a PIL image as a ``data:image/png;base64,...`` URL."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"

    def _save_temp_image(self, image: Image.Image) -> str:
        """Write ``image`` to a new temporary PNG file and return its path.

        The file is created with ``delete=False``; the caller is responsible
        for removing it. If encoding fails, the partially written file is
        removed before the exception propagates.
        """
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            # Close the handle as soon as the image is written so it is not
            # leaked and the path can be reopened/unlinked on every platform.
            with temp_file:
                image.save(temp_file, format="PNG")
        except BaseException:
            os.unlink(temp_file.name)
            raise
        return temp_file.name

    @staticmethod
    def _last_output(output):
        """Return the final item of a ``replicate.run`` result.

        A single URL string or a single file-like object is returned as-is
        (iterating those would yield characters / byte chunks, not results).
        Any other iterable is consumed and its last item is returned.
        Returns ``None`` when there is no output.
        """
        if output is None:
            return None
        if isinstance(output, str) or hasattr(output, "read"):
            return output

        last_item = None
        for last_item in output:
            pass
        return last_item

    def generate(
        self,
        input_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        custom_settings: Optional[dict] = None
    ) -> Image.Image:
        """
        Generate an image with the configured Replicate model
        (originally written for InstantID).

        Args:
            input_image: PIL Image
            prompt: Positive prompt
            negative_prompt: Negative prompt
            custom_settings: Override default settings (cfg, steps, etc.)

        Returns:
            Generated PIL Image

        Raises:
            ValueError: If the model produced no output.
            requests.RequestException: If downloading the result fails,
                times out, or returns an HTTP error status.
        """
        # Merge settings without mutating the shared defaults.
        settings = {**self.default_settings}
        if custom_settings:
            settings.update(custom_settings)

        temp_path = self._save_temp_image(input_image)

        try:
            # Keep the upload handle open until the output has been consumed,
            # then close it deterministically.
            with open(temp_path, "rb") as image_file:
                input_params = {
                    "image": image_file,
                    "prompt": prompt,
                    "negative_prompt": negative_prompt,
                    # Settings go last so they keep precedence over the
                    # keys above, matching the original merge order.
                    **settings,
                }
                output = replicate.run(self.model, input=input_params)
                result = self._last_output(output)

            if not result:
                raise ValueError("No output received from Replicate")

            if hasattr(result, "read"):
                # Some replicate client versions return file-like output
                # objects instead of URL strings; read the bytes directly.
                image_bytes = result.read()
            else:
                import requests

                response = requests.get(str(result), timeout=self.DOWNLOAD_TIMEOUT)
                # Surface HTTP errors here instead of handing an error page
                # to the image decoder.
                response.raise_for_status()
                image_bytes = response.content

            return Image.open(io.BytesIO(image_bytes))

        finally:
            # Cleanup temp file; tolerate it having been removed already.
            try:
                os.unlink(temp_path)
            except FileNotFoundError:
                pass