# fal / app.py
# (Hugging Face page header residue preserved as comments: author
# "bibibi12345", commit "fixed resolution", 58ff4d4)
"""
FAL API SeedDream v4 Edit - Web Application
A Flask-based web interface for image editing using ByteDance's SeedDream model
"""
import os
import json
import requests
import asyncio
from flask import Flask, render_template, request, jsonify, send_from_directory
from flask_cors import CORS
import fal_client
from werkzeug.utils import secure_filename
import base64
import tempfile
from pathlib import Path
import uuid
from threading import Thread
from concurrent.futures import ThreadPoolExecutor
app = Flask(__name__)
CORS(app)  # allow the browser front-end to call the API cross-origin

# Configuration
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max file size
app.config['UPLOAD_FOLDER'] = tempfile.gettempdir()

# Ensure static and template directories exist
Path("static").mkdir(exist_ok=True)
Path("templates").mkdir(exist_ok=True)

# Store active request handlers, keyed by request UUID string.
# Values are dicts with 'status', 'logs', 'result' (plus 'handler'/'error'
# once processing starts or fails).
active_requests = {}

# Thread pool for async operations
# NOTE(review): `executor` is not referenced anywhere in this file —
# confirm it is unused before removing.
executor = ThreadPoolExecutor(max_workers=10)

# Background asyncio event loop for all async tasks (uploads and generation).
# A single shared loop avoids creating/closing a loop per request.
_async_loop = asyncio.new_event_loop()

def _async_loop_runner():
    # Bind the shared loop to this daemon thread and run it until shutdown.
    asyncio.set_event_loop(_async_loop)
    _async_loop.run_forever()

_async_loop_thread = Thread(target=_async_loop_runner, name="AsyncLoopThread", daemon=True)
_async_loop_thread.start()
def run_in_async_loop(coro, timeout=None):
    """Submit *coro* to the shared background loop and block until done.

    Returns the coroutine's result, re-raises its exception, or raises
    concurrent.futures.TimeoutError if *timeout* seconds elapse first.
    """
    return asyncio.run_coroutine_threadsafe(coro, _async_loop).result(timeout)
def schedule_in_async_loop(coro):
    """Schedule coroutine in the background event loop without waiting.

    Returns the concurrent.futures.Future wrapping the coroutine so the
    caller can attach done-callbacks (fire-and-forget otherwise).
    """
    return asyncio.run_coroutine_threadsafe(coro, _async_loop)
@app.route('/')
def index():
    """Render and return the main single-page HTML interface."""
    page = render_template('index.html')
    return page
@app.route('/static/<path:filename>')
def serve_static(filename):
    """Serve static assets (CSS, JS) from the local ``static`` directory."""
    response = send_from_directory('static', filename)
    return response
async def process_fal_request(request_id, model_endpoint, fal_arguments, api_key):
    """Process FAL API request asynchronously.

    Submits *fal_arguments* to *model_endpoint* via a per-request
    ``fal_client.AsyncClient``, streams progress-log events into
    ``active_requests[request_id]['logs']``, then waits for the final
    result and records it under the same entry.

    Args:
        request_id: UUID string keying into ``active_requests``.
        model_endpoint: FAL model path, e.g. 'fal-ai/bytedance/seedream/v4/edit'.
        fal_arguments: dict of arguments forwarded verbatim to the model.
        api_key: FAL API key used only for this request (no env mutation).

    Returns:
        The FAL result dict, with collected 'logs' merged in when present.

    Raises:
        asyncio.CancelledError: re-raised after marking the entry 'error'.
        Exception: any other failure (including a result-wait timeout),
            re-raised after marking the entry 'error'.
    """
    handler = None
    try:
        print(f"[DEBUG] Starting async processing for request {request_id}")
        print(f"[DEBUG] Model endpoint: {model_endpoint}")
        print(f"[DEBUG] Arguments: {json.dumps(fal_arguments, indent=2)[:500]}...") # First 500 chars
        # Build per-request client with explicit key to avoid global env mutation
        print(f"[DEBUG] Creating AsyncClient with per-request key (not using env var)")
        client = fal_client.AsyncClient(key=api_key)
        # Submit the request asynchronously using the per-request client
        print(f"[DEBUG] Submitting to FAL API via AsyncClient.submit...")
        handler = await client.submit(
            model_endpoint,
            arguments=fal_arguments,
        )
        print(f"[DEBUG] Handler created: {handler}")
        # Store handler info.
        # NOTE(review): this replaces the entry pre-created by /api/generate
        # (dropping its state) — confirm that is intended.
        active_requests[request_id] = {
            'handler': handler,
            'status': 'processing',
            'logs': [],
            'result': None
        }
        # Collect logs asynchronously with better error handling
        print(f"[DEBUG] Starting to collect events...")
        event_count = 0
        max_events = 1000 # Prevent infinite loops
        try:
            async for event in handler.iter_events(with_logs=True):
                event_count += 1
                print(f"[DEBUG] Event #{event_count}: {type(event).__name__}")
                # Events may or may not carry log lines; only forward those that do.
                if hasattr(event, 'logs') and event.logs:
                    for log in event.logs:
                        log_msg = log.get("message", "")
                        print(f"[DEBUG] Log: {log_msg}")
                        active_requests[request_id]['logs'].append(log_msg)
                # Safety check to prevent infinite loops
                if event_count >= max_events:
                    print(f"[WARNING] Max events ({max_events}) reached, breaking loop")
                    break
        except asyncio.CancelledError:
            # Cancellation must propagate; the outer handler records the error state.
            print(f"[WARNING] Event iteration cancelled for request {request_id}")
            raise
        except RuntimeError as e:
            if "Event loop is closed" in str(e):
                print(f"[WARNING] Event loop closed during iteration, attempting to continue...")
            else:
                print(f"[WARNING] Runtime error during event iteration: {e}")
        except Exception as iter_error:
            print(f"[WARNING] Error during event iteration: {iter_error}")
            # Continue to try to get the result even if event iteration fails
        print(f"[DEBUG] Total events collected: {event_count}")
        # Get the final result with timeout to prevent hanging
        print(f"[DEBUG] Getting final result...")
        try:
            # Video endpoints are slower, so allow 5 minutes; everything else 60s.
            timeout_seconds = 300 if ('text-to-video' in model_endpoint or 'image-to-video' in model_endpoint) else 60
            result = await asyncio.wait_for(handler.get(), timeout=timeout_seconds)
        except asyncio.TimeoutError:
            print(f"[ERROR] Timeout waiting for result")
            raise Exception("Timeout waiting for FAL API result")
        print(f"[DEBUG] Result received: {json.dumps(result, indent=2)[:500] if result else 'None'}...")
        # Update request status
        active_requests[request_id]['status'] = 'completed'
        active_requests[request_id]['result'] = result
        # Add logs to result
        if active_requests[request_id]['logs']:
            result['logs'] = active_requests[request_id]['logs']
            print(f"[DEBUG] Added {len(active_requests[request_id]['logs'])} logs to result")
        print(f"[DEBUG] Async processing completed for request {request_id}")
        return result
    except asyncio.CancelledError:
        print(f"[WARNING] Task cancelled for request {request_id}")
        active_requests[request_id]['status'] = 'error'
        active_requests[request_id]['error'] = 'Task was cancelled'
        raise
    except Exception as e:
        print(f"[ERROR] Exception in process_fal_request: {str(e)}")
        import traceback
        traceback.print_exc()
        active_requests[request_id]['status'] = 'error'
        active_requests[request_id]['error'] = str(e)
        raise
def run_async_task(request_id, model_endpoint, fal_arguments, api_key):
    """Run async task with proper event loop management.

    Creates a private event loop on the calling thread, drives
    process_fal_request() to completion on it, then cancels leftover
    tasks and closes the loop with layered cleanup.

    NOTE(review): nothing in this file calls this function —
    /api/generate now schedules work on the shared background loop via
    schedule_in_async_loop(). Confirm it is dead code before removing.
    """
    print(f"[DEBUG run_async_task] Starting for request {request_id}")
    # Create and run event loop in a way that prevents premature closure
    loop = None
    try:
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        # Enable debug mode to get better error messages
        loop.set_debug(False) # Set to True for more verbose debugging
        # Run the async task in the new loop
        result = loop.run_until_complete(process_fal_request(request_id, model_endpoint, fal_arguments, api_key))
        print(f"[DEBUG run_async_task] Completed with result keys: {result.keys() if result else 'None'}")
        # Ensure all async generators are closed properly
        loop.run_until_complete(asyncio.sleep(0)) # Process any remaining callbacks
    except Exception as e:
        print(f"[ERROR run_async_task] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        # Update the request status to error
        if request_id in active_requests:
            active_requests[request_id]['status'] = 'error'
            active_requests[request_id]['error'] = str(e)
    finally:
        # More careful cleanup to avoid closing loop while still in use
        if loop:
            try:
                # Give time for any final async operations to complete
                if not loop.is_closed():
                    # Run a small delay to let any final callbacks execute
                    loop.run_until_complete(asyncio.sleep(0.5))
                    # Get all tasks that are still pending
                    pending = asyncio.all_tasks(loop)
                    # Cancel them gracefully
                    for task in pending:
                        task.cancel()
                    # Wait for all tasks to be cancelled, with a timeout
                    if pending:
                        loop.run_until_complete(
                            asyncio.wait_for(
                                asyncio.gather(*pending, return_exceptions=True),
                                timeout=5.0
                            )
                        )
            except asyncio.TimeoutError:
                print(f"[WARNING] Timeout during task cleanup")
            except Exception as cleanup_error:
                print(f"[WARNING] Error during loop cleanup: {cleanup_error}")
            finally:
                # Close the loop only after everything is done
                try:
                    if not loop.is_closed():
                        # Stop the loop first to ensure it's not running
                        loop.stop()
                        # Give it a moment to stop
                        loop.run_until_complete(asyncio.sleep(0))
                        # Now close it
                        loop.close()
                except Exception as close_error:
                    print(f"[WARNING] Error closing loop: {close_error}")
                finally:
                    # Clear the event loop from the thread
                    asyncio.set_event_loop(None)
@app.route('/api/generate', methods=['POST'])
def generate():
    """Handle image generation requests (non-blocking).

    Builds FAL arguments from the JSON body (parameters depend on the
    model family named in the X-Model-Endpoint header), schedules the
    work on the shared background loop, and immediately returns HTTP 202
    with a ``request_id`` that ``/api/status/<id>`` can poll.

    Returns:
        202 with {'request_id', 'status', 'message'} on success,
        401 when no API key is supplied, 500 on unexpected errors.
    """
    print("[DEBUG /api/generate] Endpoint called")
    try:
        # Get request data
        data = request.json
        print(f"[DEBUG /api/generate] Request data keys: {data.keys() if data else 'None'}")
        # Get model endpoint from header or default to edit
        model_endpoint = request.headers.get('X-Model-Endpoint', 'fal-ai/bytedance/seedream/v4/edit')
        print(f"[DEBUG /api/generate] Model endpoint: {model_endpoint}")
        # Get API key from header or environment
        auth_header = request.headers.get('Authorization', '')
        if auth_header.startswith('Bearer '):
            # BUGFIX: slice the prefix off instead of str.replace, which
            # would also strip any "Bearer " substring inside the key itself.
            api_key = auth_header[len('Bearer '):]
        elif os.environ.get('FAL_KEY'):
            api_key = os.environ.get('FAL_KEY')
        else:
            return jsonify({'error': 'API key not provided'}), 401
        # Prepare arguments for FAL API
        fal_arguments = {
            'prompt': data.get('prompt')
        }
        # Handle model-specific parameters
        is_text_to_image = 'text-to-image' in model_endpoint
        is_text_to_video = 'text-to-video' in model_endpoint
        is_image_to_video = 'image-to-video' in model_endpoint
        if is_text_to_image:
            # Image generation (text-to-image)
            if 'image_size' in data:
                fal_arguments['image_size'] = data['image_size']
            if 'num_images' in data:
                fal_arguments['num_images'] = data['num_images']
        elif is_text_to_video:
            # Video generation (text-to-video)
            for k in ['aspect_ratio', 'resolution', 'duration', 'camera_fixed']:
                if k in data:
                    fal_arguments[k] = data[k]
        elif is_image_to_video:
            # Video generation (image-to-video) - single image_url
            image_url = data.get('image_url')
            if not image_url:
                # Fallback: use first from image_urls if provided
                urls = data.get('image_urls', [])
                if urls:
                    image_url = urls[0]
            if image_url:
                fal_arguments['image_url'] = image_url
            for k in ['aspect_ratio', 'resolution', 'duration', 'camera_fixed']:
                if k in data:
                    fal_arguments[k] = data[k]
        else:
            # Image edit mode (default)
            if 'image_size' in data:
                fal_arguments['image_size'] = data['image_size']
            if 'num_images' in data:
                fal_arguments['num_images'] = data['num_images']
            # Forward the first 10 image URLs unchanged. (The previous loop
            # branched on data: URLs vs regular URLs but both branches were
            # identical, so the conditional was dead code.)
            fal_arguments['image_urls'] = list(data.get('image_urls', []))[:10]  # Max 10 images
            # Add max_images for edit mode
            if 'max_images' in data:
                fal_arguments['max_images'] = data['max_images']
        # Add shared optional parameters
        if 'seed' in data:
            fal_arguments['seed'] = data['seed']
        if 'enable_safety_checker' in data:
            fal_arguments['enable_safety_checker'] = data['enable_safety_checker']
        # Generate unique request ID
        request_id = str(uuid.uuid4())
        print(f"[DEBUG /api/generate] Generated request ID: {request_id}")
        # Initialize request tracking before scheduling so /api/status can
        # find the entry even if the client polls immediately.
        active_requests[request_id] = {
            'status': 'submitted',
            'logs': [],
            'result': None
        }
        print(f"[DEBUG /api/generate] Request tracking initialized")
        # Schedule async processing on background event loop
        future = schedule_in_async_loop(process_fal_request(request_id, model_endpoint, fal_arguments, api_key))

        def _bg_done_callback(fut, req_id=request_id):
            # Surface background failures in the server log; the status dict
            # itself is updated inside process_fal_request.
            try:
                fut.result()
            except Exception as e:
                print(f"[ERROR /api/generate] Background task failed for {req_id}: {e}")

        future.add_done_callback(_bg_done_callback)
        print(f"[DEBUG /api/generate] Background task scheduled on shared loop")
        # Return request ID immediately (non-blocking)
        response_data = {
            'request_id': request_id,
            'status': 'submitted',
            'message': 'Request submitted successfully'
        }
        print(f"[DEBUG /api/generate] Returning response: {response_data}")
        return jsonify(response_data), 202
    except Exception as e:
        print(f"Error in generate endpoint: {str(e)}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/status/<request_id>', methods=['GET'])
def check_status(request_id):
    """Check the status of a generation request.

    Returns the current status plus accumulated logs. Once a request has
    finished (completed or errored) its entry is removed from
    active_requests, so each outcome can be fetched exactly once.
    """
    print(f"[DEBUG /api/status] Checking status for request {request_id}")
    info = active_requests.get(request_id)
    if info is None:
        print(f"[DEBUG /api/status] Request {request_id} not found in active_requests")
        print(f"[DEBUG /api/status] Active request IDs: {list(active_requests.keys())}")
        return jsonify({'error': 'Request not found'}), 404
    print(f"[DEBUG /api/status] Request status: {info['status']}")
    print(f"[DEBUG /api/status] Request has {len(info.get('logs', []))} logs")
    payload = {
        'request_id': request_id,
        'status': info['status'],
        'logs': info.get('logs', []),
    }
    status = info['status']
    if status == 'completed':
        result = info['result']
        print(f"[DEBUG /api/status] Request completed, result keys: {result.keys() if result else 'None'}")
        if result and 'images' in result:
            print(f"[DEBUG /api/status] Found {len(result['images'])} images in result")
            for i, img in enumerate(result['images']):
                print(f"[DEBUG /api/status] Image {i+1} keys: {img.keys() if isinstance(img, dict) else type(img)}")
        payload['result'] = result
        # Remove the finished request so results are retrieved once only
        active_requests.pop(request_id, None)
        print(f"[DEBUG /api/status] Cleaned up completed request {request_id}")
    elif status == 'error':
        error_msg = info.get('error', 'Unknown error')
        print(f"[DEBUG /api/status] Request error: {error_msg}")
        payload['error'] = error_msg
        # Remove the failed request after reporting it
        active_requests.pop(request_id, None)
        print(f"[DEBUG /api/status] Cleaned up errored request {request_id}")
    print(f"[DEBUG /api/status] Returning response with status: {payload['status']}")
    return jsonify(payload), 200
@app.route('/api/upload', methods=['POST'])
def upload_file():
    """Accept a multipart file upload and echo it back as a base64 data URL."""
    try:
        uploaded = request.files.get('file')
        if uploaded is None:
            return jsonify({'error': 'No file provided'}), 400
        if uploaded.filename == '':
            return jsonify({'error': 'No file selected'}), 400
        # Encode the raw bytes as a data URL the front-end can use directly
        mime = uploaded.content_type or 'application/octet-stream'
        encoded = base64.b64encode(uploaded.read()).decode('utf-8')
        return jsonify({'url': f"data:{mime};base64,{encoded}"}), 200
    except Exception as e:
        print(f"Error in upload endpoint: {str(e)}")
        return jsonify({'error': str(e)}), 500
async def upload_file_to_fal_async(file_path, api_key):
    """Helper function to upload a file to FAL storage.

    Returns the hosted URL produced by fal_client for *file_path*.
    """
    # Use per-request AsyncClient with explicit key to avoid global env mutation
    client = fal_client.AsyncClient(key=api_key)
    return await client.upload_file(file_path)
def upload_to_fal_sync(file_path, api_key):
    """Blocking wrapper: run the async FAL upload on the shared background loop."""
    coro = upload_file_to_fal_async(file_path, api_key)
    try:
        # 5 min timeout to accommodate large uploads if needed
        return run_in_async_loop(coro, timeout=300)
    except Exception as e:
        print(f"[ERROR upload_to_fal_sync] Upload failed: {str(e)}")
        raise
@app.route('/api/upload-to-fal', methods=['POST'])
def upload_to_fal():
    """Upload base64 image data to FAL storage and return the URL.

    Body: JSON with 'image_data' — either a ``data:`` URL (uploaded to
    FAL via a temporary file) or an already-hosted URL (returned as-is).

    Returns:
        200 with {'url': ...}, 400 on missing payload, 401 without an
        API key, 500 on unexpected errors.
    """
    try:
        # Tolerate a missing/empty JSON body (previously a TypeError -> 500)
        data = request.json or {}
        if 'image_data' not in data:
            return jsonify({'error': 'No image data provided'}), 400
        # Get API key from header or environment
        auth_header = request.headers.get('Authorization', '')
        if auth_header.startswith('Bearer '):
            # BUGFIX: slice the prefix off instead of str.replace, which
            # would also strip any "Bearer " substring inside the key itself.
            api_key = auth_header[len('Bearer '):]
        elif os.environ.get('FAL_KEY'):
            api_key = os.environ.get('FAL_KEY')
        else:
            return jsonify({'error': 'API key not provided'}), 401
        image_data = data['image_data']
        if not image_data.startswith('data:'):
            # If it's already a URL, return it as-is
            return jsonify({'url': image_data}), 200
        # Extract base64 content from data URL and decode to bytes
        header, base64_content = image_data.split(',', 1)
        image_bytes = base64.b64decode(base64_content)
        # Save to a temporary file because the FAL uploader takes a path
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
            tmp_file.write(image_bytes)
            tmp_file_path = tmp_file.name
        try:
            # Use the synchronous wrapper which handles event loop properly
            fal_url = upload_to_fal_sync(tmp_file_path, api_key)
            print(f"[DEBUG] Uploaded to FAL: {fal_url}")
            return jsonify({'url': fal_url}), 200
        finally:
            # Clean up temporary file; best-effort, so only swallow OS errors
            # (BUGFIX: was a bare `except:` that hid every exception type)
            try:
                os.unlink(tmp_file_path)
            except OSError:
                pass
    except Exception as e:
        print(f"Error uploading to FAL: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe for container monitoring."""
    payload = {'status': 'healthy'}
    return jsonify(payload), 200
if __name__ == '__main__':
    # Get port from environment or default to 7860 (Hugging Face Spaces default)
    port = int(os.environ.get('PORT', 7860))
    # SPACE_ID is set only when running on Hugging Face Spaces (production)
    is_production = os.environ.get('SPACE_ID') is not None
    # Single app.run call: debug mode enabled only outside production
    app.run(host='0.0.0.0', port=port, debug=not is_production)