Spaces:
Running
Running
| """ | |
| FAL API SeedDream v4 Edit - Web Application | |
| A Flask-based web interface for image editing using ByteDance's SeedDream model | |
| """ | |
| import os | |
| import json | |
| import requests | |
| import asyncio | |
| from flask import Flask, render_template, request, jsonify, send_from_directory | |
| from flask_cors import CORS | |
| import fal_client | |
| from werkzeug.utils import secure_filename | |
| import base64 | |
| import tempfile | |
| from pathlib import Path | |
| import uuid | |
| from threading import Thread | |
| from concurrent.futures import ThreadPoolExecutor | |
# ---------------------------------------------------------------------------
# Flask application and module-wide shared state
# ---------------------------------------------------------------------------
app = Flask(__name__)
CORS(app)

# Cap request bodies at 16MB and use the system temp dir as the upload folder.
app.config.update(
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
    UPLOAD_FOLDER=tempfile.gettempdir(),
)

# Create the static/template directories (relative to the working directory)
# if they are not there yet.
Path("templates").mkdir(exist_ok=True)
Path("static").mkdir(exist_ok=True)

# request_id -> tracking dict ('status', 'logs', 'result', ...). Written by the
# request handlers and by the coroutines running on the background loop.
active_requests = {}

# Worker pool for blocking work.
executor = ThreadPoolExecutor(max_workers=10)

# Single event loop shared by every async task (uploads and generation); it is
# driven by a dedicated thread rather than by the request threads.
_async_loop = asyncio.new_event_loop()
def _async_loop_runner():
    """Thread target: install the shared loop in this thread and run it forever."""
    shared_loop = _async_loop
    asyncio.set_event_loop(shared_loop)
    shared_loop.run_forever()
# Daemon thread that drives the shared event loop for the life of the process.
_async_loop_thread = Thread(
    name="AsyncLoopThread",
    target=_async_loop_runner,
    daemon=True,
)
_async_loop_thread.start()
def run_in_async_loop(coro, timeout=None):
    """Execute *coro* on the shared background loop and block until it finishes.

    Args:
        coro: Coroutine object to run.
        timeout: Seconds to wait for the result; ``None`` waits indefinitely.

    Returns:
        Whatever the coroutine returns. Exceptions raised by the coroutine (or
        by the wait timing out) propagate to the caller.
    """
    pending = asyncio.run_coroutine_threadsafe(coro, _async_loop)
    outcome = pending.result(timeout)
    return outcome
def schedule_in_async_loop(coro):
    """Hand *coro* to the shared background loop and return immediately.

    Returns:
        The ``concurrent.futures.Future`` tracking the scheduled coroutine.
    """
    scheduled = asyncio.run_coroutine_threadsafe(coro, _async_loop)
    return scheduled
# NOTE(review): no route decorator is visible for this view in the captured
# source; confirm how (and at which URL) it is registered.
def index():
    """Render and return the main HTML page (``index.html``)."""
    page = render_template('index.html')
    return page
# NOTE(review): no route decorator is visible for this view in the captured
# source; confirm how (and at which URL) it is registered.
def serve_static(filename):
    """Return the requested asset (CSS, JS) from the ``static`` directory."""
    static_dir = 'static'
    return send_from_directory(static_dir, filename)
async def process_fal_request(request_id, model_endpoint, fal_arguments, api_key):
    """Submit a job to the FAL API, collect its logs and record the outcome.

    Progress is published through ``active_requests[request_id]`` so that the
    status endpoint can poll it from another thread. The entry's ``status`` is
    always the LAST field written on completion or failure, so a poller that
    observes a final status also observes the matching ``result``/``error``.

    Args:
        request_id: Key of the tracking entry in ``active_requests``.
        model_endpoint: FAL model endpoint identifier to submit to.
        fal_arguments: JSON-serialisable arguments for the model.
        api_key: FAL key used for this request only (no env mutation).

    Returns:
        The result returned by the FAL handler. When log messages were
        collected and the result is a dict, they are attached under ``'logs'``.

    Raises:
        asyncio.CancelledError: If the task is cancelled.
        Exception: Any submission/retrieval failure, including a timeout while
            waiting for the final result. The error is also recorded in the
            tracking entry before being re-raised.
    """
    # Work on one tracking dict for the whole run. ``setdefault`` keeps the
    # entry the caller registered (if any) and guarantees the error handlers
    # below can never fail with KeyError and hide the real exception.
    entry = active_requests.setdefault(
        request_id, {'status': 'submitted', 'logs': [], 'result': None}
    )
    try:
        print(f"[DEBUG] Starting async processing for request {request_id}")
        print(f"[DEBUG] Model endpoint: {model_endpoint}")
        # Only the first 500 chars: arguments may contain large data URLs.
        print(f"[DEBUG] Arguments: {json.dumps(fal_arguments, indent=2, default=str)[:500]}...")

        # Per-request client with an explicit key to avoid global env mutation.
        print("[DEBUG] Creating AsyncClient with per-request key (not using env var)")
        client = fal_client.AsyncClient(key=api_key)

        print("[DEBUG] Submitting to FAL API via AsyncClient.submit...")
        handler = await client.submit(
            model_endpoint,
            arguments=fal_arguments,
        )
        print(f"[DEBUG] Handler created: {handler}")

        # Update the existing entry in place instead of replacing it, so every
        # reader keeps seeing the same dict object.
        entry.update({
            'handler': handler,
            'status': 'processing',
            'logs': [],
            'result': None,
        })

        print("[DEBUG] Starting to collect events...")
        event_count = 0
        max_events = 1000  # Safety valve against an endless event stream
        try:
            async for event in handler.iter_events(with_logs=True):
                event_count += 1
                print(f"[DEBUG] Event #{event_count}: {type(event).__name__}")
                if hasattr(event, 'logs') and event.logs:
                    for log in event.logs:
                        log_msg = log.get("message", "")
                        print(f"[DEBUG] Log: {log_msg}")
                        entry['logs'].append(log_msg)
                if event_count >= max_events:
                    print(f"[WARNING] Max events ({max_events}) reached, breaking loop")
                    break
        except asyncio.CancelledError:
            print(f"[WARNING] Event iteration cancelled for request {request_id}")
            raise
        except RuntimeError as e:
            if "Event loop is closed" in str(e):
                print("[WARNING] Event loop closed during iteration, attempting to continue...")
            else:
                print(f"[WARNING] Runtime error during event iteration: {e}")
        except Exception as iter_error:
            # Log streaming is best-effort: still try to fetch the result.
            print(f"[WARNING] Error during event iteration: {iter_error}")
        print(f"[DEBUG] Total events collected: {event_count}")

        # Bound the wait for the final result; video endpoints get longer.
        print("[DEBUG] Getting final result...")
        is_video = 'text-to-video' in model_endpoint or 'image-to-video' in model_endpoint
        timeout_seconds = 300 if is_video else 60
        try:
            result = await asyncio.wait_for(handler.get(), timeout=timeout_seconds)
        except asyncio.TimeoutError as timeout_error:
            print("[ERROR] Timeout waiting for result")
            raise Exception("Timeout waiting for FAL API result") from timeout_error
        print(f"[DEBUG] Result received: {json.dumps(result, indent=2, default=str)[:500] if result else 'None'}...")

        # Attach logs BEFORE publishing, and only when the result can hold them.
        collected_logs = entry['logs']
        if collected_logs and isinstance(result, dict):
            result['logs'] = list(collected_logs)
            print(f"[DEBUG] Added {len(collected_logs)} logs to result")

        # Publish: result first, status last (see docstring).
        entry['result'] = result
        entry['status'] = 'completed'
        print(f"[DEBUG] Async processing completed for request {request_id}")
        return result
    except asyncio.CancelledError:
        print(f"[WARNING] Task cancelled for request {request_id}")
        entry['error'] = 'Task was cancelled'
        entry['status'] = 'error'
        raise
    except Exception as e:
        print(f"[ERROR] Exception in process_fal_request: {str(e)}")
        import traceback
        traceback.print_exc()
        entry['error'] = str(e)
        entry['status'] = 'error'
        raise
def run_async_task(request_id, model_endpoint, fal_arguments, api_key):
    """Run ``process_fal_request`` to completion on a private event loop.

    Creates a fresh event loop for the calling thread, runs the request on it,
    then cancels leftover tasks, shuts down async generators and closes the
    loop. Failures are logged and recorded in ``active_requests`` (when the
    entry exists); they are not re-raised.

    Args:
        request_id: Key of the tracking entry in ``active_requests``.
        model_endpoint: FAL model endpoint identifier.
        fal_arguments: Arguments forwarded to the model.
        api_key: FAL key for this request.
    """
    print(f"[DEBUG run_async_task] Starting for request {request_id}")
    loop = None
    try:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.set_debug(False)  # Set to True for more verbose debugging
        result = loop.run_until_complete(
            process_fal_request(request_id, model_endpoint, fal_arguments, api_key)
        )
        print(f"[DEBUG run_async_task] Completed with result keys: {result.keys() if result else 'None'}")
    except Exception as e:
        print(f"[ERROR run_async_task] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        if request_id in active_requests:
            active_requests[request_id]['status'] = 'error'
            active_requests[request_id]['error'] = str(e)
    finally:
        if loop is not None:
            try:
                if not loop.is_closed():
                    # Short grace period so final callbacks can run.
                    loop.run_until_complete(asyncio.sleep(0.5))
                    pending = asyncio.all_tasks(loop)
                    for task in pending:
                        task.cancel()
                    if pending:
                        # Bounded wait for the cancellations to settle.
                        loop.run_until_complete(
                            asyncio.wait_for(
                                asyncio.gather(*pending, return_exceptions=True),
                                timeout=5.0,
                            )
                        )
                    # Finalise async generators left open (e.g. an event
                    # iterator abandoned mid-stream).
                    loop.run_until_complete(loop.shutdown_asyncgens())
            except asyncio.TimeoutError:
                print("[WARNING] Timeout during task cleanup")
            except Exception as cleanup_error:
                print(f"[WARNING] Error during loop cleanup: {cleanup_error}")
            finally:
                try:
                    # The loop is idle between run_until_complete() calls, so
                    # it can be closed directly. Calling stop() first would
                    # make the next run_until_complete() raise and leave the
                    # loop unclosed.
                    if not loop.is_closed():
                        loop.close()
                except Exception as close_error:
                    print(f"[WARNING] Error closing loop: {close_error}")
                finally:
                    # Detach the (now closed) loop from this thread.
                    asyncio.set_event_loop(None)
# NOTE(review): no route decorator is visible for this view in the captured
# source; the log messages suggest it serves /api/generate — confirm.
def generate():
    """Validate a generation request and schedule it in the background.

    Reads a JSON object from the request body, builds the FAL argument dict
    for the selected model endpoint, registers a tracking entry in
    ``active_requests`` and schedules ``process_fal_request`` on the shared
    event loop. The call does not wait for the generation itself.

    Returns:
        ``(response, status)``:
        * 202 with ``request_id`` once the job is scheduled,
        * 400 when the body is not a JSON object or ``image_urls`` is invalid,
        * 401 when no API key is available,
        * 500 on unexpected errors.
    """
    print("[DEBUG /api/generate] Endpoint called")
    try:
        # silent=True: a missing/invalid JSON body yields None instead of an
        # exception, so it can be reported as a client error.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        print(f"[DEBUG /api/generate] Request data keys: {data.keys() if data else 'None'}")

        # NOTE(review): the endpoint comes straight from a client header, so a
        # caller can target any FAL endpoint using the server-side FAL_KEY
        # fallback. Consider validating it against an allow-list.
        model_endpoint = request.headers.get('X-Model-Endpoint', 'fal-ai/bytedance/seedream/v4/edit')
        print(f"[DEBUG /api/generate] Model endpoint: {model_endpoint}")

        # API key: Authorization header first, then the environment. Only the
        # leading 'Bearer ' prefix is stripped (str.replace would also remove
        # later occurrences inside the key).
        auth_header = request.headers.get('Authorization', '')
        if auth_header.startswith('Bearer '):
            api_key = auth_header[len('Bearer '):].strip()
        else:
            api_key = os.environ.get('FAL_KEY', '')
        if not api_key:
            return jsonify({'error': 'API key not provided'}), 401

        fal_arguments = {'prompt': data.get('prompt')}

        image_keys = ('image_size', 'num_images')
        video_keys = ('aspect_ratio', 'resolution', 'duration', 'camera_fixed')

        if 'text-to-image' in model_endpoint:
            optional_keys = image_keys
        elif 'text-to-video' in model_endpoint:
            optional_keys = video_keys
        elif 'image-to-video' in model_endpoint:
            # Single source image; fall back to the first of image_urls.
            image_url = data.get('image_url')
            if not image_url:
                urls = data.get('image_urls') or []
                if urls:
                    image_url = urls[0]
            if image_url:
                fal_arguments['image_url'] = image_url
            optional_keys = video_keys
        else:
            # Image edit mode (default). Data URLs and regular URLs are
            # forwarded unchanged.
            image_urls = data.get('image_urls') or []
            if not isinstance(image_urls, list) or not all(isinstance(u, str) for u in image_urls):
                return jsonify({'error': 'image_urls must be a list of URL strings'}), 400
            fal_arguments['image_urls'] = image_urls[:10]  # Max 10 images
            optional_keys = image_keys + ('max_images',)

        # Copy over whichever optional parameters the client supplied,
        # including the ones shared by every mode.
        for key in optional_keys + ('seed', 'enable_safety_checker'):
            if key in data:
                fal_arguments[key] = data[key]

        request_id = str(uuid.uuid4())
        print(f"[DEBUG /api/generate] Generated request ID: {request_id}")

        # Register the tracking entry before scheduling so a status poll can
        # never race ahead of it.
        active_requests[request_id] = {
            'status': 'submitted',
            'logs': [],
            'result': None
        }
        print("[DEBUG /api/generate] Request tracking initialized")

        try:
            future = schedule_in_async_loop(
                process_fal_request(request_id, model_endpoint, fal_arguments, api_key)
            )
        except Exception:
            # Nothing will ever update this entry; do not leave it behind.
            active_requests.pop(request_id, None)
            raise

        def _bg_done_callback(fut, req_id=request_id):
            # Retrieve the outcome so background failures get logged.
            try:
                fut.result()
            except Exception as e:
                print(f"[ERROR /api/generate] Background task failed for {req_id}: {e}")

        future.add_done_callback(_bg_done_callback)
        print("[DEBUG /api/generate] Background task scheduled on shared loop")

        response_data = {
            'request_id': request_id,
            'status': 'submitted',
            'message': 'Request submitted successfully'
        }
        print(f"[DEBUG /api/generate] Returning response: {response_data}")
        return jsonify(response_data), 202
    except Exception as e:
        print(f"Error in generate endpoint: {str(e)}")
        return jsonify({'error': str(e)}), 500
# NOTE(review): no route decorator is visible for this view in the captured
# source; the log messages suggest it serves /api/status — confirm.
def check_status(request_id):
    """Report the state of a generation request.

    A finished request (``completed`` or ``error``) is removed from
    ``active_requests`` once it has been reported, so its result or error can
    be retrieved only once.

    Args:
        request_id: Identifier returned when the request was submitted.

    Returns:
        ``(response, status)``: 404 when the id is unknown, otherwise 200 with
        ``request_id``, ``status``, ``logs`` and, once finished, ``result`` or
        ``error``.
    """
    print(f"[DEBUG /api/status] Checking status for request {request_id}")

    # Single lookup: the entry may be removed by a concurrent poll between a
    # membership test and a subscript.
    request_info = active_requests.get(request_id)
    if request_info is None:
        print(f"[DEBUG /api/status] Request {request_id} not found in active_requests")
        print(f"[DEBUG /api/status] Active request IDs: {list(active_requests.keys())}")
        return jsonify({'error': 'Request not found'}), 404

    # Snapshot the fields that the background task mutates, so the reported
    # status and the branch taken below always agree.
    status = request_info['status']
    logs = list(request_info.get('logs', []))
    print(f"[DEBUG /api/status] Request status: {status}")
    print(f"[DEBUG /api/status] Request has {len(logs)} logs")

    response = {
        'request_id': request_id,
        'status': status,
        'logs': logs
    }

    if status == 'completed':
        result = request_info['result']
        print(f"[DEBUG /api/status] Request completed, result keys: {result.keys() if isinstance(result, dict) else 'None'}")
        if isinstance(result, dict) and 'images' in result:
            print(f"[DEBUG /api/status] Found {len(result['images'])} images in result")
            for i, img in enumerate(result['images']):
                print(f"[DEBUG /api/status] Image {i+1} keys: {img.keys() if isinstance(img, dict) else type(img)}")
        response['result'] = result
        # pop() instead of del: tolerate a concurrent poll having removed it.
        active_requests.pop(request_id, None)
        print(f"[DEBUG /api/status] Cleaned up completed request {request_id}")
    elif status == 'error':
        error_msg = request_info.get('error', 'Unknown error')
        print(f"[DEBUG /api/status] Request error: {error_msg}")
        response['error'] = error_msg
        active_requests.pop(request_id, None)
        print(f"[DEBUG /api/status] Cleaned up errored request {request_id}")

    print(f"[DEBUG /api/status] Returning response with status: {response['status']}")
    return jsonify(response), 200
# NOTE(review): no route decorator is visible for this view in the captured
# source; confirm how (and at which URL) it is registered.
def upload_file():
    """Turn an uploaded file into a base64 ``data:`` URL.

    Expects the multipart field ``file``.

    Returns:
        ``(response, status)``: 200 with ``{'url': <data URL>}``, 400 when the
        field is missing or has an empty filename, 500 on unexpected errors.
    """
    try:
        uploads = request.files
        if 'file' not in uploads:
            return jsonify({'error': 'No file provided'}), 400

        upload = uploads['file']
        if upload.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        raw_bytes = upload.read()
        # Fall back to a generic binary type when the client sent none.
        mime_type = upload.content_type
        if not mime_type:
            mime_type = 'application/octet-stream'

        encoded = base64.b64encode(raw_bytes).decode('utf-8')
        payload = {'url': "data:" + mime_type + ";base64," + encoded}
        return jsonify(payload), 200
    except Exception as e:
        print(f"Error in upload endpoint: {str(e)}")
        return jsonify({'error': str(e)}), 500
async def upload_file_to_fal_async(file_path, api_key):
    """Upload the file at *file_path* to FAL storage and return the upload result.

    A dedicated ``AsyncClient`` is built with the explicit *api_key*, so the
    process environment is never modified.
    """
    uploader = fal_client.AsyncClient(key=api_key)
    uploaded = await uploader.upload_file(file_path)
    return uploaded
def upload_to_fal_sync(file_path, api_key):
    """Blocking wrapper: run the FAL upload on the shared background loop.

    Waits up to five minutes (room for large uploads). Failures are logged and
    re-raised unchanged.
    """
    try:
        upload_coro = upload_file_to_fal_async(file_path, api_key)
        return run_in_async_loop(upload_coro, timeout=300)
    except Exception as e:
        print(f"[ERROR upload_to_fal_sync] Upload failed: {str(e)}")
        raise
# NOTE(review): no route decorator is visible for this view in the captured
# source; confirm how (and at which URL) it is registered.
def upload_to_fal():
    """Upload base64 image data to FAL storage and return the hosted URL.

    Expects a JSON object with ``image_data``. A ``data:`` URL is decoded,
    written to a temporary file and uploaded; any other string is treated as
    an existing URL and returned unchanged.

    Returns:
        ``(response, status)``:
        * 200 with ``{'url': ...}``,
        * 400 when the body/``image_data`` is missing or malformed,
        * 401 when no API key is available,
        * 500 on unexpected errors (including upload failures).
    """
    import mimetypes  # local import: only this view needs it

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image_data' not in data:
            return jsonify({'error': 'No image data provided'}), 400

        # API key: Authorization header first, then the environment. Only the
        # leading 'Bearer ' prefix is stripped.
        auth_header = request.headers.get('Authorization', '')
        if auth_header.startswith('Bearer '):
            api_key = auth_header[len('Bearer '):].strip()
        else:
            api_key = os.environ.get('FAL_KEY', '')
        if not api_key:
            return jsonify({'error': 'API key not provided'}), 401

        image_data = data['image_data']
        if not isinstance(image_data, str):
            return jsonify({'error': 'image_data must be a string'}), 400

        # Already a URL: nothing to upload.
        if not image_data.startswith('data:'):
            return jsonify({'url': image_data}), 200

        # data:<mime>[;base64],<payload>
        header, separator, base64_content = image_data.partition(',')
        if not separator:
            return jsonify({'error': 'Malformed data URL'}), 400
        try:
            image_bytes = base64.b64decode(base64_content)
        except ValueError:  # binascii.Error is a ValueError subclass
            return jsonify({'error': 'Invalid base64 image data'}), 400

        # Name the temp file after the declared MIME type rather than always
        # '.png'; '.png' remains the fallback for unknown/missing types.
        mime_type = header[len('data:'):].split(';', 1)[0].strip().lower()
        suffix = mimetypes.guess_extension(mime_type) or '.png'

        tmp_file_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                # Record the path before writing so a failed write is still
                # cleaned up below.
                tmp_file_path = tmp_file.name
                tmp_file.write(image_bytes)

            fal_url = upload_to_fal_sync(tmp_file_path, api_key)
            print(f"[DEBUG] Uploaded to FAL: {fal_url}")
            return jsonify({'url': fal_url}), 200
        finally:
            # Best-effort removal of the temporary file.
            if tmp_file_path is not None:
                try:
                    os.unlink(tmp_file_path)
                except OSError as cleanup_error:
                    print(f"[WARNING] Could not remove temp file {tmp_file_path}: {cleanup_error}")
    except Exception as e:
        print(f"Error uploading to FAL: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
# NOTE(review): no route decorator is visible for this view in the captured
# source; confirm how (and at which URL) it is registered.
def health_check():
    """Liveness probe for container monitoring: always answers 200 'healthy'."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
if __name__ == '__main__':
    # Port comes from the environment; 7860 is the Hugging Face Spaces default.
    port = int(os.environ.get('PORT', 7860))

    # SPACE_ID being set is taken as the sign of a production (Spaces) run.
    is_production = os.environ.get('SPACE_ID') is not None

    # Debug mode is on everywhere except production.
    # NOTE(review): outside production this serves the debug-mode app on all
    # interfaces (0.0.0.0); confirm that exposure is intended.
    app.run(host='0.0.0.0', port=port, debug=not is_production)