File size: 2,642 Bytes
374de61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe7152
374de61
 
2e74b02
374de61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import replicate
import os
import requests
from dotenv import load_dotenv
from typing import Dict, Any, Optional

# Load variables from a local .env file into os.environ at import time so the
# REPLICATE_API_TOKEN / REPLICATE_API_KEY lookups below can find them.
load_dotenv()

# Seconds to wait for the output file host before giving up on a download.
_DOWNLOAD_TIMEOUT_SECONDS = 60

# Leading base64 characters of common image formats' magic bytes, used to pick
# the MIME type of the data URL sent to Replicate.
_BASE64_IMAGE_SIGNATURES = (
    ("iVBORw0KGgo", "image/png"),  # \x89PNG\r\n\x1a\n
    ("/9j/", "image/jpeg"),        # \xff\xd8\xff
    ("R0lGOD", "image/gif"),       # GIF87a / GIF89a
    ("UklGR", "image/webp"),       # RIFF container; assumed to be WebP here
)


def _to_data_url(base64_image: str) -> str:
    """Wrap a base64 image payload in a ``data:`` URL, detecting its MIME type.

    Strings that already are data URLs are returned unchanged. Payloads whose
    format is not recognized fall back to ``image/jpeg``, the type that was
    previously always assumed.
    """
    if base64_image.startswith("data:"):
        return base64_image
    mime_type = next(
        (mime for prefix, mime in _BASE64_IMAGE_SIGNATURES if base64_image.startswith(prefix)),
        "image/jpeg",
    )
    return f"data:{mime_type};base64,{base64_image}"


def _output_to_bytes(output: Any) -> bytes:
    """Turn the result of ``client.run`` into raw image bytes.

    Handles file-like objects (anything with ``read()``), objects exposing a
    ``url`` attribute or method, lists/tuples of outputs (the first item is
    used and resolved with these same rules) and plain URL strings.

    Raises:
        ValueError: if the model returned no output at all.
        requests.HTTPError: if downloading the output URL fails.
    """
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("Model returned an empty output list")
        # Multiple outputs: take the first, which may itself be a file object.
        return _output_to_bytes(output[0])
    if output is None:
        raise ValueError("Model returned no output")
    if hasattr(output, "read"):
        return output.read()

    url = getattr(output, "url", None)
    if callable(url):
        url = url()
    if url is None:
        # Direct URL string (or an object whose str() is the URL).
        url = str(output)

    response = requests.get(str(url), timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    # Fail loudly instead of returning an HTML/JSON error page as "image" bytes.
    response.raise_for_status()
    return response.content


def generate_image_with_model(model_name: str, prompt: str, ui_params: Optional[Dict] = None, 
                            base64_image: Optional[str] = None) -> Optional[bytes]:
    """Generate an image with the selected Replicate model and return its bytes.

    Args:
        model_name: Replicate model identifier passed to ``client.run``.
        prompt: Text prompt, sent as the ``prompt`` input.
        ui_params: Parameters chosen in the UI. They are merged with the model's
            defaults by ``multimodel_services.model_manager.get_all_parameters``.
        base64_image: Optional base64-encoded input image (or a ready-made
            ``data:`` URL), sent as a one-element ``image_input`` list.

    Returns:
        The raw bytes of the (first) generated image.

    Raises:
        Exception: on any failure (missing API token, Replicate error, empty
            output, failed download). The message is prefixed with
            ``"Failed to generate image: "`` and the original error is chained
            as ``__cause__``.
    """
    try:
        api_token = os.getenv("REPLICATE_API_TOKEN") or os.getenv("REPLICATE_API_KEY")
        if not api_token:
            raise Exception("Replicate API token not found. Please set REPLICATE_API_TOKEN environment variable.")

        client = replicate.Client(api_token=api_token)

        # Get all parameters (UI + defaults). Imported lazily, as before —
        # presumably to avoid an import cycle; verify before hoisting.
        from multimodel_services.model_manager import get_all_parameters
        all_params = get_all_parameters(model_name, ui_params)

        input_data = {
            "prompt": prompt,
            **all_params,
        }

        if base64_image:
            # Google Nano Banana expects image_input as an array.
            input_data["image_input"] = [_to_data_url(base64_image)]

        output = client.run(model_name, input=input_data)
        return _output_to_bytes(output)

    except Exception as e:
        print(f"Replicate generation failed: {str(e)}")
        raise Exception(f"Failed to generate image: {str(e)}") from e

def convert_size_to_aspect_ratio(size: str, model_name: str) -> str:
    """Translate a ``WIDTHxHEIGHT`` size into an aspect ratio for models that need one.

    Only models listed below take an aspect ratio. For them, the three known
    sizes map to ``"1:1"``, ``"3:2"`` and ``"2:3"``, and any other size
    falls back to ``"1:1"``. For every other model ``size`` is returned as-is.
    """
    models_using_aspect_ratio = ("google/nano-banana",)
    if model_name not in models_using_aspect_ratio:
        return size

    ratio_for_size = dict(
        [
            ("1024x1024", "1:1"),
            ("1536x1024", "3:2"),
            ("1024x1536", "2:3"),
        ]
    )
    return ratio_for_size.get(size, "1:1")