Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
| from typing import List, Optional, Tuple, Dict | |
| import requests | |
| import time | |
| from src.backblaze_storage import BB_uploadfile | |
| try: | |
| import replicate | |
| REPLICATE_AVAILABLE = True | |
| except ImportError: | |
| REPLICATE_AVAILABLE = False | |
class ReplicatePortraitAPI:
    """Thin wrapper around the Replicate client for portraits and LoRA training.

    The class bundles three workflows:

    * uploading a local file and obtaining a URL (or data URL) for it,
    * running the portrait-series model and downloading its outputs,
    * starting, polling and waiting for a Fast Flux trainer run.
    """

    def __init__(self, api_token: Optional[str] = None):
        """Initialize the Replicate API client.

        Args:
            api_token: Replicate API token. Falls back to the
                ``REPLICATE_API_TOKEN`` environment variable when omitted.

        Raises:
            ImportError: If the ``replicate`` package could not be imported.
            ValueError: If no token is given and the environment has none.
        """
        if not REPLICATE_AVAILABLE:
            raise ImportError("Replicate package not installed. Run: pip install replicate")

        self.api_token = api_token or os.getenv('REPLICATE_API_TOKEN')
        if not self.api_token:
            raise ValueError("REPLICATE_API_TOKEN environment variable or api_token parameter is required")

        # Exported so that module-level ``replicate.*`` helpers used elsewhere
        # in this file pick up the same token.
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        self.portrait_model = "flux-kontext-apps/portrait-series"
        self.trainer_model = "replicate/fast-flux-trainer:8b10794665aed907bb98a1a5324cd1d3a8bea0e9b31e65210967fb9c9e2e08ed"

        self.client = replicate.Client(api_token=self.api_token)

    def upload_file_to_replicate(self, file_path: str) -> str:
        """Upload a file through the client and return a URL for it.

        If the upload fails for any reason and the file has a known image
        extension, the file content is returned as a base64 ``data:`` URL
        instead.

        Args:
            file_path: Path of the local file to upload.

        Returns:
            The URL stored under ``urls['get']`` of the uploaded file, or a
            ``data:<mime>;base64,...`` URL for the image fallback.

        Raises:
            Exception: If the upload fails and the file is not an image. The
                original error is attached as ``__cause__``.
        """
        try:
            with open(file_path, 'rb') as file:
                uploaded_file = self.client.files.create(file)
                return uploaded_file.urls['get']
        except Exception as e:
            mime_types = {
                '.jpg': 'image/jpeg',
                '.jpeg': 'image/jpeg',
                '.png': 'image/png',
                '.webp': 'image/webp',
                '.gif': 'image/gif',
            }
            # Only images can be inlined as a data URL.
            if not file_path.lower().endswith(tuple(mime_types)):
                raise Exception(f"Failed to upload file: {str(e)}") from e

            import base64

            with open(file_path, 'rb') as img_file:
                img_b64 = base64.b64encode(img_file.read()).decode()
            mime_type = mime_types.get(Path(file_path).suffix.lower(), 'image/jpeg')
            return f"data:{mime_type};base64,{img_b64}"

    def download_images(self,
                        image_urls: List[str],
                        download_dir: str,
                        extension: str = "png",
                        timeout: float = 60.0) -> List[str]:
        """Download images from URLs into a local directory.

        Files are named ``portrait_NN.<extension>`` where ``NN`` is the
        1-based position of the URL in ``image_urls``. A failed download is
        reported on stdout and skipped; it does not abort the batch.

        Args:
            image_urls: URLs to fetch.
            download_dir: Target directory; created if it does not exist.
            extension: File extension for the saved files (leading dot
                optional). Defaults to ``"png"``.
            timeout: Per-request timeout in seconds passed to ``requests``.

        Returns:
            Paths of the files that were downloaded successfully.
        """
        suffix = extension.lstrip(".").lower() or "png"
        downloaded_paths: List[str] = []
        if not image_urls:
            return downloaded_paths

        os.makedirs(download_dir, exist_ok=True)
        for index, url in enumerate(image_urls, start=1):
            filepath = os.path.join(download_dir, f"portrait_{index:02d}.{suffix}")
            try:
                # A timeout keeps one stalled connection from blocking the
                # whole batch; the context manager releases the connection.
                with requests.get(url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    with open(filepath, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:
                                f.write(chunk)
            except Exception as e:
                print(f"Error downloading image {index}: {e}")
                continue
            downloaded_paths.append(filepath)
        return downloaded_paths

    def generate_portrait_series(self,
                                 input_image_path: str,
                                 num_images: int = 4,
                                 background: str = "black",
                                 randomize_images: bool = True,
                                 output_format: str = "png",
                                 safety_tolerance: int = 1,
                                 download_dir: Optional[str] = None) -> Tuple[List[str], dict]:
        """Generate a portrait series and download the results.

        Args:
            input_image_path: Local path of the source image.
            num_images: Value sent as the model's ``num_images`` input.
            background: Value sent as the model's ``background`` input.
            randomize_images: Value sent as ``randomize_images``.
            output_format: Value sent as ``output_format``; also used as the
                extension of the downloaded files.
            safety_tolerance: Value sent as ``safety_tolerance``.
            download_dir: Directory for the downloaded images. A temporary
                directory is created when omitted.

        Returns:
            Tuple of (downloaded_image_paths, api_response), where
            ``api_response`` holds the raw output, the input payload, a
            ``"succeeded"`` status and the model name.

        Raises:
            Exception: If any step fails; the original error is chained.
        """
        if download_dir is None:
            download_dir = tempfile.mkdtemp(prefix="portrait_series_")
        else:
            os.makedirs(download_dir, exist_ok=True)

        try:
            image_url = self.upload_file_to_replicate(input_image_path)
            input_data = {
                "input_image": image_url,
                "num_images": num_images,
                "background": background,
                "randomize_images": randomize_images,
                "output_format": output_format,
                "safety_tolerance": safety_tolerance,
            }

            print(f"π Running {self.portrait_model} with {num_images} images...")
            # Go through the configured client so the token given to
            # __init__ is the one actually used for this call.
            output = self.client.run(self.portrait_model, input=input_data)
            if not output:
                raise Exception("No output images generated")

            print(f"β Generated {len(output)} images, downloading...")
            downloaded_paths = self.download_images(output, download_dir, extension=output_format)

            response = {
                "output": output,
                "input": input_data,
                "status": "succeeded",
                "model": self.portrait_model,
            }
            return downloaded_paths, response
        except Exception as e:
            raise Exception(f"Error in portrait generation: {str(e)}") from e

    def start_flux_training(self,
                            input_images_zip: str,
                            destination: str,
                            trigger_word: str,
                            lora_type: str = "subject") -> Dict:
        """Start training a Fast Flux LoRA model.

        Args:
            input_images_zip: Path to a zip file of training images, or an
                ``http(s)://`` URL of an already uploaded zip. Local paths
                are uploaded with ``BB_uploadfile``; URLs are used as-is.
            destination: Replicate model destination (username/model-name).
                It is lower-cased before being sent to the API.
            trigger_word: Unique trigger word for the model (sent lower-cased).
            lora_type: Type of training - "subject" or "style".

        Returns:
            Dict containing training information (id, status, destination,
            trigger word, LoRA type, creation time, URLs and the input sent).

        Raises:
            Exception: If the destination is malformed or any API call other
                than the best-effort model creation fails.
        """
        try:
            # Validate first so a malformed destination fails before any
            # upload or remote call is made.
            owner, separator, name = destination.partition("/")
            if not (owner and separator and name) or "/" in name:
                raise ValueError(
                    f"destination must look like 'username/model-name', got {destination!r}"
                )
            owner, name = owner.lower(), name.lower()
            # The model is created lower-cased, so train into the same name.
            normalized_destination = f"{owner}/{name}"

            if input_images_zip.lower().startswith(("http://", "https://")):
                zip_url = input_images_zip
            else:
                print(f"π€ Uploading training data: {input_images_zip}")
                zip_url = BB_uploadfile(input_images_zip, os.path.basename(input_images_zip))

            training_input = {
                "input_images": zip_url,
                "trigger_word": trigger_word.lower(),
                "lora_type": lora_type.lower(),
            }

            print("π Starting Fast Flux training...")
            print(f" Destination: {destination}")
            print(f" Trigger word: {trigger_word}")
            print(f" LoRA type: {lora_type}")

            # Best effort: creation fails when the model already exists, and
            # in that case training can still proceed.
            try:
                model = self.client.models.create(
                    owner=owner,
                    name=name,
                    visibility="public",
                    hardware="gpu-a100-large",
                )
                print(f"β Model created! ID: {model.id}")
            except Exception as e:
                print(f"Error creating model: {str(e)}")

            model_name, version = self.trainer_model.split(":")
            training = self.client.trainings.create(
                model=model_name.lower(),
                version=version.lower(),
                input=training_input,
                destination=normalized_destination,
            )

            training_info = {
                "id": training.id,
                "status": training.status,
                "destination": destination,
                "trigger_word": trigger_word,
                "lora_type": lora_type,
                "created_at": getattr(training, 'created_at', None),
                "urls": getattr(training, 'urls', {}),
                "input": training_input,
            }
            print(f"β Training started! ID: {training.id}")
            return training_info
        except Exception as e:
            raise Exception(f"Error starting training: {str(e)}") from e

    def get_training_status(self, training_id: str) -> Dict:
        """Fetch a training and return its fields as a plain dict.

        Args:
            training_id: ID of the training to look up.

        Returns:
            Dict with ``id``, ``status``, ``created_at``, ``completed_at``,
            ``error``, ``logs``, ``urls`` and ``output``. Attributes missing
            on the training object are reported as ``None`` (``{}`` for
            ``urls``).

        Raises:
            Exception: If the lookup fails; the original error is chained.
        """
        try:
            training = self.client.trainings.get(training_id)
            return {
                "id": training.id,
                "status": training.status,
                "created_at": getattr(training, 'created_at', None),
                "completed_at": getattr(training, 'completed_at', None),
                "error": getattr(training, 'error', None),
                "logs": getattr(training, 'logs', None),
                "urls": getattr(training, 'urls', {}),
                "output": getattr(training, 'output', None),
            }
        except Exception as e:
            raise Exception(f"Error getting training status: {str(e)}") from e

    def wait_for_training_completion(self,
                                     training_id: str,
                                     max_wait_time: int = 3600,
                                     callback=None,
                                     poll_interval: float = 30) -> Dict:
        """Poll a training until it reaches a terminal state or times out.

        Args:
            training_id: Training ID to monitor.
            max_wait_time: Maximum time to wait in seconds (default 1 hour).
            callback: Optional callable invoked with the status dict each
                time the status value changes. Errors raised by the callback
                are printed and do not stop the wait.
            poll_interval: Seconds to sleep between polls (default 30).

        Returns:
            Final training status dict (status ``"succeeded"``).

        Raises:
            Exception: If the training failed, was canceled, or did not
                finish within ``max_wait_time`` seconds.
        """
        deadline = time.monotonic() + max_wait_time
        last_status = None

        while time.monotonic() < deadline:
            # Only the status lookup is treated as retryable; terminal-state
            # errors below are raised outside this try so they always surface.
            try:
                status = self.get_training_status(training_id)
            except Exception as e:
                print(f"Error checking training status, will retry: {e}")
                status = None

            if status is not None:
                current_status = status.get('status', 'unknown')

                if current_status != last_status:
                    last_status = current_status
                    if callback:
                        try:
                            callback(status)
                        except Exception as e:
                            print(f"Error in training status callback: {e}")

                if current_status == 'succeeded':
                    print("β Training completed successfully!")
                    return status
                if current_status == 'failed':
                    # The 'error' key may be present but None.
                    error_msg = status.get('error') or 'Unknown error occurred'
                    raise Exception(f"Training failed: {error_msg}")
                if current_status in ('canceled', 'cancelled'):
                    raise Exception("Training was canceled")

            # Never sleep past the deadline.
            remaining = deadline - time.monotonic()
            if remaining > 0:
                time.sleep(min(poll_interval, remaining))

        raise Exception(f"Training timed out after {max_wait_time} seconds")

    def list_user_models(self, username: str) -> List[Dict]:
        """List models returned by the client whose owner is ``username``.

        NOTE(review): only what a single ``models.list()`` call yields is
        inspected; confirm whether further pages need to be fetched.

        Args:
            username: Owner name to filter on (exact match).

        Returns:
            One dict per matching model with ``name``, ``owner``,
            ``description`` and ``full_name``. An empty list is returned if
            listing fails (the error is printed).
        """
        try:
            return [
                {
                    "name": model.name,
                    "owner": model.owner,
                    "description": getattr(model, 'description', ''),
                    "full_name": f"{model.owner}/{model.name}",
                }
                for model in self.client.models.list()
                if hasattr(model, 'owner') and model.owner == username
            ]
        except Exception as e:
            print(f"Error listing models: {e}")
            return []
def test_api():
    """Check that the API wrapper can be constructed.

    Prints the outcome and returns True when ReplicatePortraitAPI() was
    created without raising, False when the replicate package is missing
    or construction failed.
    """
    if REPLICATE_AVAILABLE:
        try:
            # Construction alone is the check; the instance is not needed.
            ReplicatePortraitAPI()
            print("β API initialized successfully")
            return True
        except Exception as init_error:
            print(f"β API initialization failed: {init_error}")
            return False

    print("β Replicate package not installed. Run: pip install replicate")
    return False
def quick_test_with_sample():
    """Run the portrait-series model once against a sample image URL.

    Uses the module-level ``replicate.run`` helper, so the API token must be
    available to the replicate package (e.g. via the environment).

    Returns:
        True if the model call succeeded, False if the replicate package is
        unavailable or the call raised.
    """
    if not REPLICATE_AVAILABLE:
        print("β Replicate package not available")
        # Return a bool here too so every path has the same return type.
        return False

    sample_input = {
        "background": "black",
        "num_images": 2,  # Small number for testing
        "input_image": "https://replicate.delivery/pbxt/N5DZJkCEuP5rWGtu8XcfyZj9sXzm4W3OXOSfdJnj9NmlirP2/mona-lisa.png",
        "output_format": "png",
        "randomize_images": True,
        "safety_tolerance": 1,
    }
    try:
        output = replicate.run("flux-kontext-apps/portrait-series", input=sample_input)
        print(f"β Test successful! Generated {len(output)} images")
        print("Sample URLs:", output[:2])
        return True
    except Exception as e:
        print(f"β Test failed: {e}")
        return False
if __name__ == "__main__":
    # Script entry point: print a banner, check that the API wrapper can be
    # initialized, and show setup instructions when it cannot.
    print("π¨ Replicate Portrait Series & Fast Flux Training API")
    print("=" * 50)

    api_ready = test_api()
    if not api_ready:
        setup_lines = (
            "\nπ Setup instructions:",
            "1. pip install replicate",
            "2. Set REPLICATE_API_TOKEN environment variable",
            "3. Get token from: https://replicate.com/account",
        )
        for setup_line in setup_lines:
            print(setup_line)
    else:
        print("\nπ API ready to use!")
        # A live end-to-end check is available via quick_test_with_sample();
        # it is not run by default.