"""Helpers for Replicate portrait-series generation and Fast Flux LoRA training."""

import base64
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import requests

from src.backblaze_storage import BB_uploadfile

try:
    import replicate
    REPLICATE_AVAILABLE = True
except ImportError:
    REPLICATE_AVAILABLE = False


# Image extensions eligible for the data-URL fallback, mapped to MIME types.
_IMAGE_MIME_TYPES = {
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.webp': 'image/webp',
    '.gif': 'image/gif',
}


class ReplicatePortraitAPI:
    """Thin wrapper around the Replicate client for portraits and LoRA training."""

    def __init__(self, api_token: Optional[str] = None):
        """Initialize Replicate API client.

        Args:
            api_token: Replicate API token. Falls back to the
                REPLICATE_API_TOKEN environment variable when omitted.

        Raises:
            ImportError: If the ``replicate`` package is not installed.
            ValueError: If no token is supplied and none is set in the
                environment.
        """
        if not REPLICATE_AVAILABLE:
            raise ImportError("Replicate package not installed. Run: pip install replicate")

        self.api_token = api_token or os.getenv('REPLICATE_API_TOKEN')
        if not self.api_token:
            raise ValueError("REPLICATE_API_TOKEN environment variable or api_token parameter is required")

        # Also export the token so module-level replicate calls (e.g. in
        # quick_test_with_sample) pick it up.
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        self.portrait_model = "flux-kontext-apps/portrait-series"
        self.trainer_model = "replicate/fast-flux-trainer:8b10794665aed907bb98a1a5324cd1d3a8bea0e9b31e65210967fb9c9e2e08ed"

        # Initialize client
        self.client = replicate.Client(api_token=self.api_token)

    def upload_file_to_replicate(self, file_path: str) -> str:
        """Upload file to Replicate and get URL.

        If the upload fails and the file has a known image extension, the
        file is returned inline as a base64 ``data:`` URL instead.

        Args:
            file_path: Path of the local file to upload.

        Returns:
            The URL of the uploaded file, or a ``data:`` URL for the
            image fallback.

        Raises:
            Exception: If the upload fails and the file is not an image.
        """
        try:
            with open(file_path, 'rb') as file:
                uploaded_file = self.client.files.create(file)
            return uploaded_file.urls['get']
        except Exception as e:
            # Fallback: convert to data URL for images only
            if not file_path.lower().endswith(tuple(_IMAGE_MIME_TYPES)):
                raise Exception(f"Failed to upload file: {str(e)}") from e

            with open(file_path, 'rb') as img_file:
                img_b64 = base64.b64encode(img_file.read()).decode()

            ext = Path(file_path).suffix.lower()
            mime_type = _IMAGE_MIME_TYPES.get(ext, 'image/jpeg')
            return f"data:{mime_type};base64,{img_b64}"

    def download_images(self,
                        image_urls: List[str],
                        download_dir: str,
                        extension: str = "png",
                        timeout: float = 60.0) -> List[str]:
        """Download images from URLs to local directory.

        Downloads are best-effort: a failing URL is reported and skipped so
        the remaining images are still fetched.

        Args:
            image_urls: URLs to download, in order.
            download_dir: Existing directory that receives the files.
            extension: File extension (without dot) used for the saved
                files; files are named ``portrait_NN.<extension>``.
            timeout: Timeout in seconds passed to ``requests.get`` so a
                stalled server cannot hang the call indefinitely.

        Returns:
            Paths of the files that were downloaded successfully.
        """
        suffix = extension.lstrip(".") or "png"
        downloaded_paths = []

        for index, url in enumerate(image_urls, start=1):
            filepath = os.path.join(download_dir, f"portrait_{index:02d}.{suffix}")
            try:
                # The context manager releases the streamed connection even
                # when the download fails part-way through.
                with requests.get(str(url), stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    with open(filepath, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
            except Exception as e:
                # Deliberately broad: one bad image must not abort the batch.
                print(f"Error downloading image {index}: {e}")
                continue
            downloaded_paths.append(filepath)

        return downloaded_paths

    def generate_portrait_series(self,
                                 input_image_path: str,
                                 num_images: int = 4,
                                 background: str = "black",
                                 randomize_images: bool = True,
                                 output_format: str = "png",
                                 safety_tolerance: int = 1,
                                 download_dir: Optional[str] = None) -> Tuple[List[str], dict]:
        """Generate portrait series using Replicate API.

        Args:
            input_image_path: Local path of the source portrait.
            num_images: Number of images requested from the model.
            background: Background setting forwarded to the model.
            randomize_images: Forwarded to the model as-is.
            output_format: Image format forwarded to the model; also used
                as the extension of the downloaded files.
            safety_tolerance: Forwarded to the model as-is.
            download_dir: Directory for the results. A temporary directory
                is created when omitted.

        Returns:
            Tuple of (downloaded_image_paths, api_response)

        Raises:
            Exception: If upload, generation, or post-processing fails.
        """
        created_temp_dir = download_dir is None
        if created_temp_dir:
            download_dir = tempfile.mkdtemp(prefix="portrait_series_")
        else:
            os.makedirs(download_dir, exist_ok=True)

        try:
            # Upload input image
            image_url = self.upload_file_to_replicate(input_image_path)

            input_data = {
                "input_image": image_url,
                "num_images": num_images,
                "background": background,
                "randomize_images": randomize_images,
                "output_format": output_format,
                "safety_tolerance": safety_tolerance
            }

            print(f"🔄 Running {self.portrait_model} with {num_images} images...")
            # Use the configured client rather than the module-level default
            # so the token passed to __init__ is the one actually used.
            output = self.client.run(self.portrait_model, input=input_data)

            if not output:
                raise Exception("No output images generated")

            # Normalise to plain URL strings: a lone string is one result,
            # and non-string output items are converted with str().
            items = [output] if isinstance(output, str) else list(output)
            output_urls = [str(item) for item in items]
            if not output_urls:
                raise Exception("No output images generated")

            print(f"✅ Generated {len(output_urls)} images, downloading...")
            downloaded_paths = self.download_images(
                output_urls, download_dir, extension=output_format
            )

            # Create response dict for compatibility
            response = {
                "output": output_urls,
                "input": input_data,
                "status": "succeeded",
                "model": self.portrait_model
            }
            return downloaded_paths, response

        except Exception as e:
            # Only remove directories this call created; a caller-supplied
            # directory may hold unrelated files.
            if created_temp_dir:
                shutil.rmtree(download_dir, ignore_errors=True)
            raise Exception(f"Error in portrait generation: {str(e)}") from e

    def start_flux_training(self,
                            input_images_zip: str,
                            destination: str,
                            trigger_word: str,
                            lora_type: str = "subject") -> Dict:
        """Start training a Fast Flux LoRA model.

        Args:
            input_images_zip: Path to zip file containing training images OR
                URL to uploaded zip
            destination: Replicate model destination (username/model-name)
            trigger_word: Unique trigger word for the model
            lora_type: Type of training - "subject" or "style"

        Returns:
            Dict containing training information

        Raises:
            Exception: If the upload or the training request fails.
        """
        try:
            if input_images_zip.lower().startswith(("http://", "https://")):
                # Already hosted: pass the URL through instead of treating
                # it as a local path to upload.
                zip_url = input_images_zip
            else:
                print(f"📤 Uploading training data: {input_images_zip}")
                zip_url = BB_uploadfile(input_images_zip, os.path.basename(input_images_zip))

            training_input = {
                "input_images": zip_url,
                "trigger_word": trigger_word.lower(),
                "lora_type": lora_type.lower()
            }

            print(f"🚀 Starting Fast Flux training...")
            print(f" Destination: {destination}")
            print(f" Trigger word: {trigger_word}")
            print(f" LoRA type: {lora_type}")

            # Create model if it doesn't exist
            try:
                owner, name = destination.split("/")
                model = self.client.models.create(
                    owner=owner.lower(),
                    name=name.lower(),
                    visibility="public",
                    hardware="gpu-a100-large"
                )
                print(f"✅ Model created! ID: {model.id}")
            except Exception as e:
                error_message = f"Error creating model: {str(e)}"
                print(error_message)
                # Continue anyway in case model already exists

            model_name, version = self.trainer_model.split(":")
            training = self.client.trainings.create(
                model=model_name.lower(),
                version=version.lower(),
                input=training_input,
                destination=destination
            )

            training_info = {
                "id": training.id,
                "status": training.status,
                "destination": destination,
                "trigger_word": trigger_word,
                "lora_type": lora_type,
                "created_at": getattr(training, 'created_at', None),
                "urls": getattr(training, 'urls', {}),
                "input": training_input
            }

            print(f"✅ Training started! ID: {training.id}")
            return training_info

        except Exception as e:
            raise Exception(f"Error starting training: {str(e)}") from e

    def get_training_status(self, training_id: str) -> Dict:
        """Get the status of a training.

        Args:
            training_id: ID of the training to look up.

        Returns:
            Dict with id, status, timestamps, error, logs, urls and output;
            attributes missing on the training object are reported as None
            (or ``{}`` for urls).

        Raises:
            Exception: If the training cannot be fetched.
        """
        try:
            training = self.client.trainings.get(training_id)
            return {
                "id": training.id,
                "status": training.status,
                "created_at": getattr(training, 'created_at', None),
                "completed_at": getattr(training, 'completed_at', None),
                "error": getattr(training, 'error', None),
                "logs": getattr(training, 'logs', None),
                "urls": getattr(training, 'urls', {}),
                "output": getattr(training, 'output', None)
            }
        except Exception as e:
            raise Exception(f"Error getting training status: {str(e)}") from e

    def wait_for_training_completion(self,
                                     training_id: str,
                                     max_wait_time: int = 3600,
                                     callback: Optional[Callable[[Dict], None]] = None,
                                     poll_interval: float = 30) -> Dict:
        """Wait for training to complete.

        Args:
            training_id: Training ID to monitor
            max_wait_time: Maximum time to wait in seconds (default 1 hour)
            callback: Optional callback function to call with status updates;
                invoked once per status change. Errors raised by the
                callback are reported and do not stop the wait.
            poll_interval: Seconds between status checks (default 30)

        Returns:
            Final training status dict

        Raises:
            Exception: If the training fails, is canceled, or does not
                finish within ``max_wait_time``.
        """
        deadline = time.monotonic() + max_wait_time
        last_status = None

        while time.monotonic() < deadline:
            try:
                status = self.get_training_status(training_id)
            except Exception as e:
                # Lookup errors are treated as transient: keep polling
                # until the deadline.
                print(f"Error polling training {training_id}: {e}")
                status = None

            if status is not None:
                current_status = status.get('status', 'unknown')

                if callback and current_status != last_status:
                    try:
                        callback(status)
                    except Exception as e:
                        # A faulty callback must not hide a terminal state.
                        print(f"Status callback raised an error: {e}")
                last_status = current_status

                if current_status == 'succeeded':
                    print(f"✅ Training completed successfully!")
                    return status
                if current_status == 'failed':
                    # The 'error' key is always present but may be None.
                    error_msg = status.get('error') or 'Unknown error occurred'
                    raise Exception(f"Training failed: {error_msg}")
                if current_status in ('canceled', 'cancelled'):
                    raise Exception("Training was canceled")

            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            time.sleep(max(0.0, min(poll_interval, remaining)))

        raise Exception(f"Training timed out after {max_wait_time} seconds")

    def list_user_models(self, username: str) -> List[Dict]:
        """List models for a user.

        Args:
            username: Owner name to match exactly against each model's owner.

        Returns:
            A list of dicts with name, owner, description and full_name.
            Returns an empty list if listing fails.
        """
        try:
            # NOTE(review): models.list() may return only one page of
            # results -- confirm whether pagination is needed here.
            models = self.client.models.list()
            return [
                {
                    "name": model.name,
                    "owner": model.owner,
                    "description": getattr(model, 'description', ''),
                    "full_name": f"{model.owner}/{model.name}"
                }
                for model in models
                if hasattr(model, 'owner') and model.owner == username
            ]
        except Exception as e:
            print(f"Error listing models: {e}")
            return []


def test_api() -> bool:
    """Test function to verify API functionality.

    Returns:
        True if the client could be constructed, False otherwise.
    """
    if not REPLICATE_AVAILABLE:
        print("❌ Replicate package not installed. Run: pip install replicate")
        return False

    try:
        ReplicatePortraitAPI()
        print("✅ API initialized successfully")
        return True
    except Exception as e:
        print(f"❌ API initialization failed: {e}")
        return False


def quick_test_with_sample() -> bool:
    """Quick test with a sample image URL.

    Returns:
        True if the sample run produced output, False otherwise.
    """
    if not REPLICATE_AVAILABLE:
        print("❌ Replicate package not available")
        return False

    try:
        # This is a quick test using the example from the docs
        output = replicate.run(
            "flux-kontext-apps/portrait-series",
            input={
                "background": "black",
                "num_images": 2,  # Small number for testing
                "input_image": "https://replicate.delivery/pbxt/N5DZJkCEuP5rWGtu8XcfyZj9sXzm4W3OXOSfdJnj9NmlirP2/mona-lisa.png",
                "output_format": "png",
                "randomize_images": True,
                "safety_tolerance": 1
            }
        )
        print(f"✅ Test successful! Generated {len(output)} images")
        print("Sample URLs:", output[:2])
        return True
    except Exception as e:
        print(f"❌ Test failed: {e}")
        return False


if __name__ == "__main__":
    print("🎨 Replicate Portrait Series & Fast Flux Training API")
    print("=" * 50)

    # Test basic initialization
    if test_api():
        print("\n🚀 API ready to use!")
        # Optionally run a quick test (uncomment to test)
        # print("\n🧪 Running quick test...")
        # quick_test_with_sample()
    else:
        print("\n📋 Setup instructions:")
        print("1. pip install replicate")
        print("2. Set REPLICATE_API_TOKEN environment variable")
        print("3. Get token from: https://replicate.com/account")