bibibi12345 commited on
Commit
9cbf868
·
1 Parent(s): c5f0b8c

aync change

Browse files
Files changed (2) hide show
  1. app.py +109 -23
  2. static/script.js +47 -5
app.py CHANGED
@@ -6,6 +6,7 @@ A Flask-based web interface for image editing using ByteDance's SeedDream model
6
  import os
7
  import json
8
  import requests
 
9
  from flask import Flask, render_template, request, jsonify, send_from_directory
10
  from flask_cors import CORS
11
  import fal_client
@@ -13,6 +14,9 @@ from werkzeug.utils import secure_filename
13
  import base64
14
  import tempfile
15
  from pathlib import Path
 
 
 
16
 
17
  app = Flask(__name__)
18
  CORS(app)
@@ -25,6 +29,12 @@ app.config['UPLOAD_FOLDER'] = tempfile.gettempdir()
25
  Path("static").mkdir(exist_ok=True)
26
  Path("templates").mkdir(exist_ok=True)
27
 
 
 
 
 
 
 
28
  @app.route('/')
29
  def index():
30
  """Serve the main HTML interface"""
@@ -35,9 +45,62 @@ def serve_static(filename):
35
  """Serve static files (CSS, JS)"""
36
  return send_from_directory('static', filename)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @app.route('/api/generate', methods=['POST'])
39
  def generate():
40
- """Handle image generation requests"""
41
  try:
42
  # Get request data
43
  data = request.json
@@ -49,9 +112,9 @@ def generate():
49
  auth_header = request.headers.get('Authorization', '')
50
  if auth_header.startswith('Bearer '):
51
  api_key = auth_header.replace('Bearer ', '')
52
- # Temporarily set the API key for this request
53
- os.environ['FAL_KEY'] = api_key
54
- elif not os.environ.get('FAL_KEY'):
55
  return jsonify({'error': 'API key not provided'}), 401
56
 
57
  # Prepare arguments for FAL API
@@ -88,33 +151,56 @@ def generate():
88
  if 'enable_safety_checker' in data:
89
  fal_arguments['enable_safety_checker'] = data['enable_safety_checker']
90
 
91
- # Create a logs collector
92
- logs = []
93
 
94
- def on_queue_update(update):
95
- """Handle queue updates and collect logs"""
96
- if isinstance(update, fal_client.InProgress):
97
- for log in update.logs:
98
- logs.append(log.get("message", ""))
99
-
100
- # Call FAL API with subscribe (blocking call)
101
- result = fal_client.subscribe(
102
- model_endpoint,
103
- arguments=fal_arguments,
104
- with_logs=True,
105
- on_queue_update=on_queue_update
106
- )
107
 
108
- # Add logs to the response
109
- if logs:
110
- result['logs'] = logs
111
 
112
- return jsonify(result), 200
 
 
 
 
 
113
 
114
  except Exception as e:
115
  print(f"Error in generate endpoint: {str(e)}")
116
  return jsonify({'error': str(e)}), 500
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  @app.route('/api/upload', methods=['POST'])
119
  def upload_file():
120
  """Handle file uploads and return data URL"""
 
6
  import os
7
  import json
8
  import requests
9
+ import asyncio
10
  from flask import Flask, render_template, request, jsonify, send_from_directory
11
  from flask_cors import CORS
12
  import fal_client
 
14
  import base64
15
  import tempfile
16
  from pathlib import Path
17
+ import uuid
18
+ from threading import Thread
19
+ from concurrent.futures import ThreadPoolExecutor
20
 
21
  app = Flask(__name__)
22
  CORS(app)
 
29
  Path("static").mkdir(exist_ok=True)
30
  Path("templates").mkdir(exist_ok=True)
31
 
32
+ # Store active request handlers
33
+ active_requests = {}
34
+
35
+ # Thread pool for async operations
36
+ executor = ThreadPoolExecutor(max_workers=10)
37
+
38
  @app.route('/')
39
  def index():
40
  """Serve the main HTML interface"""
 
45
  """Serve static files (CSS, JS)"""
46
  return send_from_directory('static', filename)
47
 
48
+ async def process_fal_request(request_id, model_endpoint, fal_arguments, api_key):
49
+ """Process FAL API request asynchronously"""
50
+ try:
51
+ # Set API key for this request
52
+ os.environ['FAL_KEY'] = api_key
53
+
54
+ # Submit the request asynchronously
55
+ handler = await fal_client.submit_async(
56
+ model_endpoint,
57
+ arguments=fal_arguments,
58
+ )
59
+
60
+ # Store handler info
61
+ active_requests[request_id] = {
62
+ 'handler': handler,
63
+ 'status': 'processing',
64
+ 'logs': [],
65
+ 'result': None
66
+ }
67
+
68
+ # Collect logs asynchronously
69
+ async for event in handler.iter_events(with_logs=True):
70
+ if hasattr(event, 'logs') and event.logs:
71
+ for log in event.logs:
72
+ active_requests[request_id]['logs'].append(log.get("message", ""))
73
+
74
+ # Get the final result
75
+ result = await handler.get()
76
+
77
+ # Update request status
78
+ active_requests[request_id]['status'] = 'completed'
79
+ active_requests[request_id]['result'] = result
80
+
81
+ # Add logs to result
82
+ if active_requests[request_id]['logs']:
83
+ result['logs'] = active_requests[request_id]['logs']
84
+
85
+ return result
86
+
87
+ except Exception as e:
88
+ active_requests[request_id]['status'] = 'error'
89
+ active_requests[request_id]['error'] = str(e)
90
+ raise
91
+
92
+ def run_async_task(request_id, model_endpoint, fal_arguments, api_key):
93
+ """Run async task in a new event loop"""
94
+ loop = asyncio.new_event_loop()
95
+ asyncio.set_event_loop(loop)
96
+ try:
97
+ loop.run_until_complete(process_fal_request(request_id, model_endpoint, fal_arguments, api_key))
98
+ finally:
99
+ loop.close()
100
+
101
  @app.route('/api/generate', methods=['POST'])
102
  def generate():
103
+ """Handle image generation requests (non-blocking)"""
104
  try:
105
  # Get request data
106
  data = request.json
 
112
  auth_header = request.headers.get('Authorization', '')
113
  if auth_header.startswith('Bearer '):
114
  api_key = auth_header.replace('Bearer ', '')
115
+ elif os.environ.get('FAL_KEY'):
116
+ api_key = os.environ.get('FAL_KEY')
117
+ else:
118
  return jsonify({'error': 'API key not provided'}), 401
119
 
120
  # Prepare arguments for FAL API
 
151
  if 'enable_safety_checker' in data:
152
  fal_arguments['enable_safety_checker'] = data['enable_safety_checker']
153
 
154
+ # Generate unique request ID
155
+ request_id = str(uuid.uuid4())
156
 
157
+ # Initialize request tracking
158
+ active_requests[request_id] = {
159
+ 'status': 'submitted',
160
+ 'logs': [],
161
+ 'result': None
162
+ }
 
 
 
 
 
 
 
163
 
164
+ # Start async processing in background thread
165
+ thread = Thread(target=run_async_task, args=(request_id, model_endpoint, fal_arguments, api_key))
166
+ thread.start()
167
 
168
+ # Return request ID immediately (non-blocking)
169
+ return jsonify({
170
+ 'request_id': request_id,
171
+ 'status': 'submitted',
172
+ 'message': 'Request submitted successfully'
173
+ }), 202
174
 
175
  except Exception as e:
176
  print(f"Error in generate endpoint: {str(e)}")
177
  return jsonify({'error': str(e)}), 500
178
 
179
+ @app.route('/api/status/<request_id>', methods=['GET'])
180
+ def check_status(request_id):
181
+ """Check the status of a generation request"""
182
+ if request_id not in active_requests:
183
+ return jsonify({'error': 'Request not found'}), 404
184
+
185
+ request_info = active_requests[request_id]
186
+
187
+ response = {
188
+ 'request_id': request_id,
189
+ 'status': request_info['status'],
190
+ 'logs': request_info.get('logs', [])
191
+ }
192
+
193
+ if request_info['status'] == 'completed':
194
+ response['result'] = request_info['result']
195
+ # Clean up completed request after retrieval
196
+ del active_requests[request_id]
197
+ elif request_info['status'] == 'error':
198
+ response['error'] = request_info.get('error', 'Unknown error')
199
+ # Clean up errored request after retrieval
200
+ del active_requests[request_id]
201
+
202
+ return jsonify(response), 200
203
+
204
  @app.route('/api/upload', methods=['POST'])
205
  def upload_file():
206
  """Handle file uploads and return data URL"""
static/script.js CHANGED
@@ -341,9 +341,10 @@ function getAPIKey() {
341
  return apiKey || localStorage.getItem('fal_api_key');
342
  }
343
 
344
- // Call FAL API (proxy through backend)
345
  async function callFalAPI(apiKey, requestData, model) {
346
- const response = await fetch('/api/generate', {
 
347
  method: 'POST',
348
  headers: {
349
  'Content-Type': 'application/json',
@@ -353,12 +354,53 @@ async function callFalAPI(apiKey, requestData, model) {
353
  body: JSON.stringify(requestData)
354
  });
355
 
356
- if (!response.ok) {
357
- const error = await response.text();
358
  throw new Error(error || 'API request failed');
359
  }
360
 
361
- return await response.json();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  }
363
 
364
  // Display results
 
341
  return apiKey || localStorage.getItem('fal_api_key');
342
  }
343
 
344
+ // Call FAL API (proxy through backend) - Non-blocking version
345
  async function callFalAPI(apiKey, requestData, model) {
346
+ // Submit the request (non-blocking)
347
+ const submitResponse = await fetch('/api/generate', {
348
  method: 'POST',
349
  headers: {
350
  'Content-Type': 'application/json',
 
354
  body: JSON.stringify(requestData)
355
  });
356
 
357
+ if (!submitResponse.ok) {
358
+ const error = await submitResponse.text();
359
  throw new Error(error || 'API request failed');
360
  }
361
 
362
+ const { request_id } = await submitResponse.json();
363
+ addLog(`Request submitted with ID: ${request_id}`);
364
+
365
+ // Poll for results
366
+ let attempts = 0;
367
+ const maxAttempts = 120; // 2 minutes with 1-second intervals
368
+ const pollInterval = 1000; // 1 second
369
+
370
+ while (attempts < maxAttempts) {
371
+ await new Promise(resolve => setTimeout(resolve, pollInterval));
372
+
373
+ const statusResponse = await fetch(`/api/status/${request_id}`);
374
+ if (!statusResponse.ok) {
375
+ throw new Error('Failed to check request status');
376
+ }
377
+
378
+ const statusData = await statusResponse.json();
379
+
380
+ // Add any new logs
381
+ if (statusData.logs && statusData.logs.length > 0) {
382
+ statusData.logs.forEach(log => {
383
+ if (log && !log.includes('Request submitted')) {
384
+ addLog(log);
385
+ }
386
+ });
387
+ }
388
+
389
+ if (statusData.status === 'completed') {
390
+ return statusData.result;
391
+ } else if (statusData.status === 'error') {
392
+ throw new Error(statusData.error || 'Generation failed');
393
+ }
394
+
395
+ attempts++;
396
+
397
+ // Update progress indicator
398
+ if (attempts % 5 === 0) {
399
+ addLog(`Processing... (${attempts}s elapsed)`);
400
+ }
401
+ }
402
+
403
+ throw new Error('Request timed out after 2 minutes');
404
  }
405
 
406
  // Display results