AdGenesis-App / multimodel_services /replicate_generation_service.py
userIdc2024's picture
Update multimodel_services/replicate_generation_service.py
1fe7152 verified
import replicate
import os
import requests
from dotenv import load_dotenv
from typing import Dict, Any, Optional
load_dotenv()
# Seconds allowed for connecting to, and for each read from, the result URL.
_DOWNLOAD_TIMEOUT_SECONDS = 60


def _download_bytes(url: str) -> bytes:
    """Fetch *url* and return the response body, raising on HTTP errors."""
    response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    # Without this an error page would be returned as if it were image data.
    response.raise_for_status()
    return response.content


def _output_to_bytes(output: Any) -> bytes:
    """Turn whatever ``client.run`` returned into raw image bytes."""
    # Multiple outputs: only the first one is used.
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("Replicate returned an empty output list")
        output = output[0]

    if output is None:
        raise ValueError("Replicate returned no output")

    # File-like output: read it directly instead of downloading it again.
    if hasattr(output, "read"):
        return output.read()

    # Object exposing its location; `url` may be an attribute or a method.
    if hasattr(output, "url"):
        url = output.url
        if callable(url):
            url = url()
        return _download_bytes(str(url))

    # Anything else is treated as a direct URL string.
    return _download_bytes(str(output))


def generate_image_with_model(
    model_name: str,
    prompt: str,
    ui_params: Optional[Dict] = None,
    base64_image: Optional[str] = None,
) -> Optional[bytes]:
    """Generate an image with the selected Replicate model.

    Args:
        model_name: Replicate model identifier passed to ``client.run``.
        prompt: Text prompt sent as the ``prompt`` input.
        ui_params: Optional parameters forwarded to ``get_all_parameters``,
            whose result is merged into the model input.
        base64_image: Optional base64-encoded image. When given it is sent as
            a one-element ``image_input`` list of data URLs.

    Returns:
        The bytes of the first generated image.

    Raises:
        Exception: If the API token is missing, generation fails, or the
            result cannot be downloaded. The original error is attached as
            ``__cause__``.
    """
    try:
        # Initialize Replicate client
        api_token = os.getenv("REPLICATE_API_TOKEN") or os.getenv("REPLICATE_API_KEY")
        if not api_token:
            raise Exception("Replicate API token not found. Please set REPLICATE_API_TOKEN environment variable.")
        client = replicate.Client(api_token=api_token)

        # Get all parameters (UI + defaults)
        from multimodel_services.model_manager import get_all_parameters
        all_params = get_all_parameters(model_name, ui_params)

        # Build input dict; keys in all_params take precedence over "prompt".
        input_data = {
            "prompt": prompt,
            **all_params,
        }

        if base64_image:
            # Convert base64 to a data URL for Replicate. The MIME type is
            # hard-coded, so every image is labelled as JPEG.
            image_data_url = f"data:image/jpeg;base64,{base64_image}"
            # Google Nano Banana expects image_input as an array
            input_data["image_input"] = [image_data_url]

        output = client.run(model_name, input=input_data)
        return _output_to_bytes(output)
    except Exception as e:
        print(f"Replicate generation failed: {str(e)}")
        # Chain the original error so its traceback is not lost.
        raise Exception(f"Failed to generate image: {str(e)}") from e
def convert_size_to_aspect_ratio(size: str, model_name: str) -> str:
    """Translate a pixel size string into an aspect ratio where required.

    For ``google/nano-banana`` a known ``WIDTHxHEIGHT`` size is mapped to its
    aspect ratio, and any unknown size falls back to ``"1:1"``. For every
    other model the size is returned unchanged.
    """
    # Only these models take an aspect ratio instead of a pixel size.
    if model_name not in ("google/nano-banana",):
        return size

    ratio_by_size = {
        "1024x1024": "1:1",
        "1536x1024": "3:2",
        "1024x1536": "2:3",
    }
    if size in ratio_by_size:
        return ratio_by_size[size]
    return "1:1"