# Flux-Trainer / src/replicate_call.py
# Daniel Jarvis
# Application files V1
# f4a907c
import base64
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import requests

try:
    import replicate
    REPLICATE_AVAILABLE = True
except ImportError:
    REPLICATE_AVAILABLE = False

from src.backblaze_storage import BB_uploadfile
class ReplicatePortraitAPI:
    """Wrapper around the Replicate API for the portrait-series model and Fast Flux LoRA training.

    Portrait generation runs ``self.portrait_model``; LoRA training runs
    ``self.trainer_model`` and writes the result to a user-owned destination model.
    """

    # Extension -> MIME type for the image formats accepted by the data-URL upload
    # fallback; also the set of extensions trusted when naming downloaded images.
    _IMAGE_MIME_TYPES = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.webp': 'image/webp',
        '.gif': 'image/gif',
    }

    def __init__(self, api_token: Optional[str] = None):
        """Initialize the Replicate API client.

        Args:
            api_token: Replicate API token. Falls back to the ``REPLICATE_API_TOKEN``
                environment variable when omitted.

        Raises:
            ImportError: if the ``replicate`` package is not installed.
            ValueError: if no token is passed or found in the environment.
        """
        if not REPLICATE_AVAILABLE:
            raise ImportError("Replicate package not installed. Run: pip install replicate")
        self.api_token = api_token or os.getenv('REPLICATE_API_TOKEN')
        if not self.api_token:
            raise ValueError("REPLICATE_API_TOKEN environment variable or api_token parameter is required")
        # Also export the token process-wide so module-level ``replicate.*`` calls
        # made elsewhere keep authenticating; this class itself uses ``self.client``.
        os.environ['REPLICATE_API_TOKEN'] = self.api_token
        self.portrait_model = "flux-kontext-apps/portrait-series"
        self.trainer_model = "replicate/fast-flux-trainer:8b10794665aed907bb98a1a5324cd1d3a8bea0e9b31e65210967fb9c9e2e08ed"
        self.client = replicate.Client(api_token=self.api_token)

    @staticmethod
    def _output_urls(output) -> List[str]:
        """Normalize a model-run result into a list of URL strings.

        Accepts ``None``, a single URL, a single object exposing a ``url`` attribute,
        or an iterable of either. Objects with a ``url`` attribute (e.g. file-output
        objects returned by some replicate client versions -- verify against the
        installed version) are reduced to that URL.
        """
        if output is None:
            return []
        # A single str / file-like object must not be iterated: a str would split
        # into characters, and a file-like object may yield content chunks.
        if isinstance(output, str) or hasattr(output, "url"):
            output = [output]
        return [str(getattr(item, "url", item)) for item in output]

    def upload_file_to_replicate(self, file_path: str) -> str:
        """Upload a local file to Replicate and return a URL the API can read.

        If the upload fails and the file has an image extension, fall back to an
        inline base64 ``data:`` URL instead.

        Raises:
            Exception: if the upload fails for a non-image file (original error chained).
            OSError: if the fallback cannot read the image file.
        """
        try:
            with open(file_path, 'rb') as file:
                uploaded_file = self.client.files.create(file)
            return uploaded_file.urls['get']
        except Exception as e:
            mime_type = self._IMAGE_MIME_TYPES.get(Path(file_path).suffix.lower())
            if mime_type is None:
                raise Exception(f"Failed to upload file: {str(e)}") from e
            # Surface the reason so auth/network problems are not hidden by the fallback.
            print(f"Upload to Replicate failed ({e}); sending image inline as a data URL")
            with open(file_path, 'rb') as img_file:
                img_b64 = base64.b64encode(img_file.read()).decode()
            return f"data:{mime_type};base64,{img_b64}"

    def download_images(self, image_urls: List[str], download_dir: str, timeout: float = 60) -> List[str]:
        """Download images into ``download_dir`` as ``portrait_NN<ext>``.

        Best-effort: a failed download is printed and skipped, and a partially
        written file for it is removed. ``<ext>`` is taken from the URL path when it
        is a known image extension, otherwise ``.png``.

        Args:
            image_urls: URLs to fetch (objects with a ``url`` attribute are accepted).
            download_dir: Existing directory to write into.
            timeout: Per-request timeout in seconds so a stalled server cannot hang the call.

        Returns:
            Paths of the successfully downloaded files, in input order.
        """
        downloaded_paths = []
        for i, url in enumerate(self._output_urls(image_urls), start=1):
            suffix = Path(urlparse(url).path).suffix.lower()
            if suffix not in self._IMAGE_MIME_TYPES:
                suffix = ".png"
            filepath = os.path.join(download_dir, f"portrait_{i:02d}{suffix}")
            started_writing = False
            try:
                # Context manager closes the streamed connection even on errors.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    started_writing = True
                    with open(filepath, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
            except Exception as e:
                print(f"Error downloading image {i}: {e}")
                # Do not leave a truncated file behind for this failed download.
                if started_writing and os.path.exists(filepath):
                    os.remove(filepath)
                continue
            downloaded_paths.append(filepath)
        return downloaded_paths

    def generate_portrait_series(self,
                                 input_image_path: str,
                                 num_images: int = 4,
                                 background: str = "black",
                                 randomize_images: bool = True,
                                 output_format: str = "png",
                                 safety_tolerance: int = 1,
                                 download_dir: Optional[str] = None) -> Tuple[List[str], dict]:
        """
        Generate a portrait series with the Replicate portrait-series model.

        Args:
            input_image_path: Local path of the source image (uploaded first).
            num_images, background, randomize_images, output_format, safety_tolerance:
                Passed through as model inputs.
            download_dir: Where to save results; a fresh temp dir is created if None.

        Returns:
            Tuple of (downloaded_image_paths, api_response) where ``api_response``
            holds the output URLs, the inputs sent, a status and the model name.

        Raises:
            Exception: wrapping any upload, model-run or empty-output failure.
        """
        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="portrait_series_")
        else:
            os.makedirs(download_dir, exist_ok=True)
        try:
            image_url = self.upload_file_to_replicate(input_image_path)
            input_data = {
                "input_image": image_url,
                "num_images": num_images,
                "background": background,
                "randomize_images": randomize_images,
                "output_format": output_format,
                "safety_tolerance": safety_tolerance,
            }
            print(f"🔄 Running {self.portrait_model} with {num_images} images...")
            # Use this instance's client (and thus its token) rather than the
            # module-level replicate.run, which authenticates via the environment.
            output_urls = self._output_urls(self.client.run(self.portrait_model, input=input_data))
            if not output_urls:
                raise Exception("No output images generated")
            print(f"✅ Generated {len(output_urls)} images, downloading...")
            downloaded_paths = self.download_images(output_urls, download_dir)
            response = {
                "output": output_urls,
                "input": input_data,
                "status": "succeeded",
                "model": self.portrait_model,
            }
            return downloaded_paths, response
        except Exception as e:
            raise Exception(f"Error in portrait generation: {str(e)}") from e

    @staticmethod
    def _split_destination(destination: str) -> Tuple[str, str]:
        """Split ``owner/model-name`` into lowercase ``(owner, name)``.

        Raises:
            ValueError: unless ``destination`` has exactly two non-empty ``/``-separated parts.
        """
        parts = [part.strip() for part in destination.split("/")]
        if len(parts) != 2 or not all(parts):
            raise ValueError(f"destination must look like 'owner/model-name', got {destination!r}")
        owner, name = parts
        return owner.lower(), name.lower()

    @staticmethod
    def _resolve_training_data_url(input_images_zip: str) -> str:
        """Return a URL for the training zip, uploading it via ``BB_uploadfile`` if it is a local path.

        Raises:
            FileNotFoundError: if a non-URL path does not point to an existing file.
        """
        if input_images_zip.lower().startswith(("http://", "https://")):
            return input_images_zip
        if not os.path.isfile(input_images_zip):
            raise FileNotFoundError(f"Training data not found: {input_images_zip}")
        print(f"📤 Uploading training data: {input_images_zip}")
        return BB_uploadfile(input_images_zip, os.path.basename(input_images_zip))

    def _ensure_model_exists(self, owner: str, name: str, visibility: str, hardware: str) -> None:
        """Best-effort creation of the destination model; failures are printed, not raised."""
        try:
            model = self.client.models.create(
                owner=owner,
                name=name,
                visibility=visibility,
                hardware=hardware,
            )
            print(f"✅ Model created! ID: {model.id}")
        except Exception as e:
            # Continue anyway: the model may already exist. A genuinely unusable
            # destination is expected to surface from trainings.create instead.
            print(f"Error creating model: {str(e)}")

    def start_flux_training(self,
                            input_images_zip: str,
                            destination: str,
                            trigger_word: str,
                            lora_type: str = "subject",
                            visibility: str = "public",
                            hardware: str = "gpu-a100-large") -> Dict:
        """
        Start training a Fast Flux LoRA model.

        Args:
            input_images_zip: Local path to a zip of training images (uploaded via
                ``BB_uploadfile``) or an http(s) URL of an already-uploaded zip.
            destination: Replicate model destination ``owner/model-name``. It is
                lowercased and used both to create the model and as the training target.
            trigger_word: Unique trigger word for the model (sent lowercased).
            lora_type: Type of training - "subject" or "style" (sent lowercased).
            visibility: Visibility used when creating the destination model.
            hardware: Hardware used when creating the destination model.

        Returns:
            Dict containing training information.

        Raises:
            Exception: wrapping any failure (invalid destination, missing zip,
                upload or API errors); the original error is chained.
        """
        try:
            # Validate before uploading anything, and use one normalized
            # destination for both model creation and training.
            owner, name = self._split_destination(destination)
            destination = f"{owner}/{name}"
            zip_url = self._resolve_training_data_url(input_images_zip)
            training_input = {
                "input_images": zip_url,
                "trigger_word": trigger_word.lower(),
                "lora_type": lora_type.lower(),
            }
            print("🚀 Starting Fast Flux training...")
            print(f" Destination: {destination}")
            print(f" Trigger word: {trigger_word}")
            print(f" LoRA type: {lora_type}")
            self._ensure_model_exists(owner, name, visibility, hardware)
            model_name, version = self.trainer_model.split(":")
            training = self.client.trainings.create(
                model=model_name.lower(),
                version=version.lower(),
                input=training_input,
                destination=destination,
            )
            training_info = {
                "id": training.id,
                "status": training.status,
                "destination": destination,
                "trigger_word": trigger_word,
                "lora_type": lora_type,
                "created_at": getattr(training, 'created_at', None),
                "urls": getattr(training, 'urls', {}),
                "input": training_input,
            }
            print(f"✅ Training started! ID: {training.id}")
            return training_info
        except Exception as e:
            raise Exception(f"Error starting training: {str(e)}") from e

    def get_training_status(self, training_id: str) -> Dict:
        """Fetch a training and return its id, status, timestamps, error, logs, urls and output.

        Raises:
            Exception: wrapping any API error (original error chained).
        """
        try:
            training = self.client.trainings.get(training_id)
            return {
                "id": training.id,
                "status": training.status,
                "created_at": getattr(training, 'created_at', None),
                "completed_at": getattr(training, 'completed_at', None),
                "error": getattr(training, 'error', None),
                "logs": getattr(training, 'logs', None),
                "urls": getattr(training, 'urls', {}),
                "output": getattr(training, 'output', None),
            }
        except Exception as e:
            raise Exception(f"Error getting training status: {str(e)}") from e

    def wait_for_training_completion(self, training_id: str, max_wait_time: int = 3600, callback=None,
                                     poll_interval: float = 30) -> Dict:
        """
        Poll a training until it succeeds, fails, is canceled, or time runs out.

        Args:
            training_id: Training ID to monitor.
            max_wait_time: Maximum time to wait in seconds (default 1 hour).
            callback: Optional callable invoked with the status dict whenever the
                status value changes. Exceptions it raises propagate.
            poll_interval: Seconds between polls (default 30).

        Returns:
            Final training status dict (status ``succeeded``).

        Raises:
            Exception: if the training failed or was canceled, or on timeout.
                Errors while fetching status are printed and retried until the deadline.
        """
        deadline = time.monotonic() + max_wait_time
        last_status = None
        while time.monotonic() < deadline:
            # Only the status fetch is treated as transient; terminal states
            # below are raised directly instead of being filtered by message text.
            try:
                status = self.get_training_status(training_id)
            except Exception as e:
                print(f"Error polling training {training_id}: {e}")
                status = None

            if status is not None:
                current_status = status.get('status', 'unknown')
                if callback and current_status != last_status:
                    callback(status)
                last_status = current_status
                if current_status == 'succeeded':
                    print("✅ Training completed successfully!")
                    return status
                if current_status == 'failed':
                    # 'error' may be present but None; fall back to a readable message.
                    error_msg = status.get('error') or 'Unknown error occurred'
                    raise Exception(f"Training failed: {error_msg}")
                if current_status in ('canceled', 'cancelled'):
                    raise Exception("Training was canceled")

            # Cap the sleep so we never wait past the deadline.
            time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
        raise Exception(f"Training timed out after {max_wait_time} seconds")

    def list_user_models(self, username: str) -> List[Dict]:
        """List models returned by ``client.models.list()`` whose owner equals ``username``.

        Returns an empty list (after printing the error) if the API call fails.
        """
        # NOTE(review): depending on the replicate client version, models.list()
        # may return only the first page of public models, so matching models on
        # later pages would be missed -- confirm and paginate if needed.
        try:
            models = self.client.models.list()
            return [
                {
                    "name": model.name,
                    "owner": model.owner,
                    "description": getattr(model, 'description', ''),
                    "full_name": f"{model.owner}/{model.name}",
                }
                for model in models
                if hasattr(model, 'owner') and model.owner == username
            ]
        except Exception as e:
            print(f"Error listing models: {e}")
            return []
def test_api():
    """Check that ReplicatePortraitAPI can be constructed in this environment.

    Prints a status line. Returns True on success, and False when the replicate
    package is missing or initialization raises (e.g. no API token configured).
    """
    if not REPLICATE_AVAILABLE:
        print("❌ Replicate package not installed. Run: pip install replicate")
        return False
    try:
        # Construction alone is the check; the instance itself is not needed.
        ReplicatePortraitAPI()
    except Exception as e:
        print(f"❌ API initialization failed: {e}")
        return False
    print("✅ API initialized successfully")
    return True
def quick_test_with_sample():
    """Run the portrait-series model once on a public sample image.

    Makes a live call through the module-level ``replicate.run``, which relies on
    the REPLICATE_API_TOKEN environment variable for authentication.

    Returns:
        True if the call succeeded. False otherwise, including when the replicate
        package is not installed (previously this path returned None).
    """
    if not REPLICATE_AVAILABLE:
        print("❌ Replicate package not available")
        return False
    # Same inputs as the model's documentation example.
    sample_input = {
        "background": "black",
        "num_images": 2,  # Small number for testing
        "input_image": "https://replicate.delivery/pbxt/N5DZJkCEuP5rWGtu8XcfyZj9sXzm4W3OXOSfdJnj9NmlirP2/mona-lisa.png",
        "output_format": "png",
        "randomize_images": True,
        "safety_tolerance": 1,
    }
    try:
        output = replicate.run("flux-kontext-apps/portrait-series", input=sample_input)
        print(f"✅ Test successful! Generated {len(output)} images")
        print("Sample URLs:", output[:2])
        return True
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False
if __name__ == "__main__":
    print("🎨 Replicate Portrait Series & Fast Flux Training API")
    print("=" * 50)
    # Verify the package and token are usable before suggesting further steps.
    if test_api():
        print("\n🚀 API ready to use!")
        # Optionally run a live model call (uncomment to test)
        # print("\n🧪 Running quick test...")
        # quick_test_with_sample()
    else:
        print("\n📋 Setup instructions:")
        print("1. pip install replicate")
        print("2. Set REPLICATE_API_TOKEN environment variable")
        print("3. Get token from: https://replicate.com/account")