""" FAL API SeedDream v4 Edit - Web Application A Flask-based web interface for image editing using ByteDance's SeedDream model """ import os import json import requests import asyncio from flask import Flask, render_template, request, jsonify, send_from_directory from flask_cors import CORS import fal_client from werkzeug.utils import secure_filename import base64 import tempfile from pathlib import Path import uuid from threading import Thread from concurrent.futures import ThreadPoolExecutor app = Flask(__name__) CORS(app) # Configuration app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size app.config['UPLOAD_FOLDER'] = tempfile.gettempdir() # Ensure static and template directories exist Path("static").mkdir(exist_ok=True) Path("templates").mkdir(exist_ok=True) # Store active request handlers active_requests = {} # Thread pool for async operations executor = ThreadPoolExecutor(max_workers=10) # Background asyncio event loop for all async tasks (uploads and generation) _async_loop = asyncio.new_event_loop() def _async_loop_runner(): asyncio.set_event_loop(_async_loop) _async_loop.run_forever() _async_loop_thread = Thread(target=_async_loop_runner, name="AsyncLoopThread", daemon=True) _async_loop_thread.start() def run_in_async_loop(coro, timeout=None): """Run coroutine in the background event loop and wait for result.""" future = asyncio.run_coroutine_threadsafe(coro, _async_loop) return future.result(timeout) def schedule_in_async_loop(coro): """Schedule coroutine in the background event loop without waiting.""" return asyncio.run_coroutine_threadsafe(coro, _async_loop) @app.route('/') def index(): """Serve the main HTML interface""" return render_template('index.html') @app.route('/static/') def serve_static(filename): """Serve static files (CSS, JS)""" return send_from_directory('static', filename) async def process_fal_request(request_id, model_endpoint, fal_arguments, api_key): """Process FAL API request asynchronously""" handler = None try: print(f"[DEBUG] Starting async processing for request {request_id}") print(f"[DEBUG] Model endpoint: {model_endpoint}") print(f"[DEBUG] Arguments: {json.dumps(fal_arguments, indent=2)[:500]}...") # First 500 chars # Build per-request client with explicit key to avoid global env mutation print(f"[DEBUG] Creating AsyncClient with per-request key (not using env var)") client = fal_client.AsyncClient(key=api_key) # Submit the request asynchronously using the per-request client print(f"[DEBUG] Submitting to FAL API via AsyncClient.submit...") handler = await client.submit( model_endpoint, arguments=fal_arguments, ) print(f"[DEBUG] Handler created: {handler}") # Store handler info active_requests[request_id] = { 'handler': handler, 'status': 'processing', 'logs': [], 'result': None } # Collect logs asynchronously with better error handling print(f"[DEBUG] Starting to collect events...") event_count = 0 max_events = 1000 # Prevent infinite loops try: async for event in handler.iter_events(with_logs=True): event_count += 1 print(f"[DEBUG] Event #{event_count}: {type(event).__name__}") if hasattr(event, 'logs') and event.logs: for log in event.logs: log_msg = log.get("message", "") print(f"[DEBUG] Log: {log_msg}") active_requests[request_id]['logs'].append(log_msg) # Safety check to prevent infinite loops if event_count >= max_events: print(f"[WARNING] Max events ({max_events}) reached, breaking loop") break except asyncio.CancelledError: print(f"[WARNING] Event iteration cancelled for request {request_id}") raise except RuntimeError as e: if "Event loop is closed" in str(e): print(f"[WARNING] Event loop closed during iteration, attempting to continue...") else: print(f"[WARNING] Runtime error during event iteration: {e}") except Exception as iter_error: print(f"[WARNING] Error during event iteration: {iter_error}") # Continue to try to get the result even if event iteration fails print(f"[DEBUG] Total events collected: {event_count}") # Get the final result with timeout to prevent hanging print(f"[DEBUG] Getting final result...") try: timeout_seconds = 300 if ('text-to-video' in model_endpoint or 'image-to-video' in model_endpoint) else 60 result = await asyncio.wait_for(handler.get(), timeout=timeout_seconds) except asyncio.TimeoutError: print(f"[ERROR] Timeout waiting for result") raise Exception("Timeout waiting for FAL API result") print(f"[DEBUG] Result received: {json.dumps(result, indent=2)[:500] if result else 'None'}...") # Update request status active_requests[request_id]['status'] = 'completed' active_requests[request_id]['result'] = result # Add logs to result if active_requests[request_id]['logs']: result['logs'] = active_requests[request_id]['logs'] print(f"[DEBUG] Added {len(active_requests[request_id]['logs'])} logs to result") print(f"[DEBUG] Async processing completed for request {request_id}") return result except asyncio.CancelledError: print(f"[WARNING] Task cancelled for request {request_id}") active_requests[request_id]['status'] = 'error' active_requests[request_id]['error'] = 'Task was cancelled' raise except Exception as e: print(f"[ERROR] Exception in process_fal_request: {str(e)}") import traceback traceback.print_exc() active_requests[request_id]['status'] = 'error' active_requests[request_id]['error'] = str(e) raise def run_async_task(request_id, model_endpoint, fal_arguments, api_key): """Run async task with proper event loop management""" print(f"[DEBUG run_async_task] Starting for request {request_id}") # Create and run event loop in a way that prevents premature closure loop = None try: # Create a new event loop for this thread loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) # Enable debug mode to get better error messages loop.set_debug(False) # Set to True for more verbose debugging # Run the async task in the new loop result = loop.run_until_complete(process_fal_request(request_id, model_endpoint, fal_arguments, api_key)) print(f"[DEBUG run_async_task] Completed with result keys: {result.keys() if result else 'None'}") # Ensure all async generators are closed properly loop.run_until_complete(asyncio.sleep(0)) # Process any remaining callbacks except Exception as e: print(f"[ERROR run_async_task] Failed: {str(e)}") import traceback traceback.print_exc() # Update the request status to error if request_id in active_requests: active_requests[request_id]['status'] = 'error' active_requests[request_id]['error'] = str(e) finally: # More careful cleanup to avoid closing loop while still in use if loop: try: # Give time for any final async operations to complete if not loop.is_closed(): # Run a small delay to let any final callbacks execute loop.run_until_complete(asyncio.sleep(0.5)) # Get all tasks that are still pending pending = asyncio.all_tasks(loop) # Cancel them gracefully for task in pending: task.cancel() # Wait for all tasks to be cancelled, with a timeout if pending: loop.run_until_complete( asyncio.wait_for( asyncio.gather(*pending, return_exceptions=True), timeout=5.0 ) ) except asyncio.TimeoutError: print(f"[WARNING] Timeout during task cleanup") except Exception as cleanup_error: print(f"[WARNING] Error during loop cleanup: {cleanup_error}") finally: # Close the loop only after everything is done try: if not loop.is_closed(): # Stop the loop first to ensure it's not running loop.stop() # Give it a moment to stop loop.run_until_complete(asyncio.sleep(0)) # Now close it loop.close() except Exception as close_error: print(f"[WARNING] Error closing loop: {close_error}") finally: # Clear the event loop from the thread asyncio.set_event_loop(None) @app.route('/api/generate', methods=['POST']) def generate(): """Handle image generation requests (non-blocking)""" print("[DEBUG /api/generate] Endpoint called") try: # Get request data data = request.json print(f"[DEBUG /api/generate] Request data keys: {data.keys() if data else 'None'}") # Get model endpoint from header or default to edit model_endpoint = request.headers.get('X-Model-Endpoint', 'fal-ai/bytedance/seedream/v4/edit') print(f"[DEBUG /api/generate] Model endpoint: {model_endpoint}") # Get API key from header or environment auth_header = request.headers.get('Authorization', '') if auth_header.startswith('Bearer '): api_key = auth_header.replace('Bearer ', '') elif os.environ.get('FAL_KEY'): api_key = os.environ.get('FAL_KEY') else: return jsonify({'error': 'API key not provided'}), 401 # Prepare arguments for FAL API fal_arguments = { 'prompt': data.get('prompt') } # Handle model-specific parameters is_text_to_image = 'text-to-image' in model_endpoint is_text_to_video = 'text-to-video' in model_endpoint is_image_to_video = 'image-to-video' in model_endpoint if is_text_to_image: # Image generation (text-to-image) if 'image_size' in data: fal_arguments['image_size'] = data['image_size'] if 'num_images' in data: fal_arguments['num_images'] = data['num_images'] elif is_text_to_video: # Video generation (text-to-video) for k in ['aspect_ratio', 'resolution', 'duration', 'camera_fixed']: if k in data: fal_arguments[k] = data[k] elif is_image_to_video: # Video generation (image-to-video) - single image_url image_url = data.get('image_url') if not image_url: # Fallback: use first from image_urls if provided urls = data.get('image_urls', []) if urls: image_url = urls[0] if image_url: fal_arguments['image_url'] = image_url for k in ['aspect_ratio', 'resolution', 'duration', 'camera_fixed']: if k in data: fal_arguments[k] = data[k] else: # Image edit mode (default) if 'image_size' in data: fal_arguments['image_size'] = data['image_size'] if 'num_images' in data: fal_arguments['num_images'] = data['num_images'] processed_image_urls = [] for url in data.get('image_urls', []): if url.startswith('data:'): # Handle base64 data URLs processed_image_urls.append(url) else: # Regular URL processed_image_urls.append(url) fal_arguments['image_urls'] = processed_image_urls[:10] # Max 10 images # Add max_images for edit mode if 'max_images' in data: fal_arguments['max_images'] = data['max_images'] # Add shared optional parameters if 'seed' in data: fal_arguments['seed'] = data['seed'] if 'enable_safety_checker' in data: fal_arguments['enable_safety_checker'] = data['enable_safety_checker'] # Generate unique request ID request_id = str(uuid.uuid4()) print(f"[DEBUG /api/generate] Generated request ID: {request_id}") # Initialize request tracking active_requests[request_id] = { 'status': 'submitted', 'logs': [], 'result': None } print(f"[DEBUG /api/generate] Request tracking initialized") # Schedule async processing on background event loop future = schedule_in_async_loop(process_fal_request(request_id, model_endpoint, fal_arguments, api_key)) def _bg_done_callback(fut, req_id=request_id): try: fut.result() except Exception as e: print(f"[ERROR /api/generate] Background task failed for {req_id}: {e}") future.add_done_callback(_bg_done_callback) print(f"[DEBUG /api/generate] Background task scheduled on shared loop") # Return request ID immediately (non-blocking) response_data = { 'request_id': request_id, 'status': 'submitted', 'message': 'Request submitted successfully' } print(f"[DEBUG /api/generate] Returning response: {response_data}") return jsonify(response_data), 202 except Exception as e: print(f"Error in generate endpoint: {str(e)}") return jsonify({'error': str(e)}), 500 @app.route('/api/status/', methods=['GET']) def check_status(request_id): """Check the status of a generation request""" print(f"[DEBUG /api/status] Checking status for request {request_id}") if request_id not in active_requests: print(f"[DEBUG /api/status] Request {request_id} not found in active_requests") print(f"[DEBUG /api/status] Active request IDs: {list(active_requests.keys())}") return jsonify({'error': 'Request not found'}), 404 request_info = active_requests[request_id] print(f"[DEBUG /api/status] Request status: {request_info['status']}") print(f"[DEBUG /api/status] Request has {len(request_info.get('logs', []))} logs") response = { 'request_id': request_id, 'status': request_info['status'], 'logs': request_info.get('logs', []) } if request_info['status'] == 'completed': result = request_info['result'] print(f"[DEBUG /api/status] Request completed, result keys: {result.keys() if result else 'None'}") if result and 'images' in result: print(f"[DEBUG /api/status] Found {len(result['images'])} images in result") for i, img in enumerate(result['images']): print(f"[DEBUG /api/status] Image {i+1} keys: {img.keys() if isinstance(img, dict) else type(img)}") response['result'] = result # Clean up completed request after retrieval del active_requests[request_id] print(f"[DEBUG /api/status] Cleaned up completed request {request_id}") elif request_info['status'] == 'error': error_msg = request_info.get('error', 'Unknown error') print(f"[DEBUG /api/status] Request error: {error_msg}") response['error'] = error_msg # Clean up errored request after retrieval del active_requests[request_id] print(f"[DEBUG /api/status] Cleaned up errored request {request_id}") print(f"[DEBUG /api/status] Returning response with status: {response['status']}") return jsonify(response), 200 @app.route('/api/upload', methods=['POST']) def upload_file(): """Handle file uploads and return data URL""" try: if 'file' not in request.files: return jsonify({'error': 'No file provided'}), 400 file = request.files['file'] if file.filename == '': return jsonify({'error': 'No file selected'}), 400 # Read file and convert to base64 data URL file_content = file.read() file_type = file.content_type or 'application/octet-stream' base64_content = base64.b64encode(file_content).decode('utf-8') data_url = f"data:{file_type};base64,{base64_content}" return jsonify({'url': data_url}), 200 except Exception as e: print(f"Error in upload endpoint: {str(e)}") return jsonify({'error': str(e)}), 500 async def upload_file_to_fal_async(file_path, api_key): """Helper function to upload a file to FAL storage""" # Use per-request AsyncClient with explicit key to avoid global env mutation client = fal_client.AsyncClient(key=api_key) return await client.upload_file(file_path) def upload_to_fal_sync(file_path, api_key): """Run upload coroutine on the shared background event loop""" try: # 5 min timeout to accommodate large uploads if needed return run_in_async_loop(upload_file_to_fal_async(file_path, api_key), timeout=300) except Exception as e: print(f"[ERROR upload_to_fal_sync] Upload failed: {str(e)}") raise @app.route('/api/upload-to-fal', methods=['POST']) def upload_to_fal(): """Upload base64 image data to FAL storage and return the URL""" try: data = request.json if 'image_data' not in data: return jsonify({'error': 'No image data provided'}), 400 # Get API key from header or environment auth_header = request.headers.get('Authorization', '') if auth_header.startswith('Bearer '): api_key = auth_header.replace('Bearer ', '') elif os.environ.get('FAL_KEY'): api_key = os.environ.get('FAL_KEY') else: return jsonify({'error': 'API key not provided'}), 401 image_data = data['image_data'] # If it's a base64 data URL, extract the actual base64 content if image_data.startswith('data:'): # Extract base64 content from data URL header, base64_content = image_data.split(',', 1) # Decode base64 to bytes image_bytes = base64.b64decode(base64_content) # Save to temporary file with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file: tmp_file.write(image_bytes) tmp_file_path = tmp_file.name try: # Use the synchronous wrapper which handles event loop properly fal_url = upload_to_fal_sync(tmp_file_path, api_key) print(f"[DEBUG] Uploaded to FAL: {fal_url}") return jsonify({'url': fal_url}), 200 finally: # Clean up temporary file try: os.unlink(tmp_file_path) except: pass else: # If it's already a URL, return it as-is return jsonify({'url': image_data}), 200 except Exception as e: print(f"Error uploading to FAL: {str(e)}") import traceback traceback.print_exc() return jsonify({'error': str(e)}), 500 @app.route('/health', methods=['GET']) def health_check(): """Health check endpoint for container monitoring""" return jsonify({'status': 'healthy'}), 200 if __name__ == '__main__': # Get port from environment or default to 7860 (Hugging Face Spaces default) port = int(os.environ.get('PORT', 7860)) # Check if running in production (Hugging Face Spaces) is_production = os.environ.get('SPACE_ID') is not None # Run the application if is_production: app.run(host='0.0.0.0', port=port, debug=False) else: app.run(host='0.0.0.0', port=port, debug=True)