Devity4756 commited on
Commit
b11d52a
·
verified ·
1 Parent(s): 0a448d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -200
app.py CHANGED
@@ -1,236 +1,124 @@
1
- # spam.py (optimized backend)
2
- # Install: pip install flask flask-socketio gradio_client eventlet
3
 
4
- from flask import Flask, request, render_template
5
  from flask_socketio import SocketIO, emit
6
  from gradio_client import Client, handle_file
7
- import os
8
- import base64
9
- import logging
10
- import threading
11
- import time
12
- import re
13
  from datetime import datetime, timedelta
14
 
15
- # Set up logging
16
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
17
  logger = logging.getLogger(__name__)
18
 
 
19
  app = Flask(__name__)
20
- # Increased timeouts to prevent disconnections during long API calls
21
- socketio = SocketIO(app,
22
- cors_allowed_origins="*",
23
- ping_timeout=300, # Increased to 5 minutes
24
- ping_interval=60, # Increased to 1 minute
25
- max_http_buffer_size=10 * 1024 * 1024, # 10MB max file size
26
- async_mode='eventlet')
27
-
28
- # Replace with your ACTUAL runtime URL from "Use via API" (e.g., https://tonyassi-voice-clone.hf.space)
29
- HF_SPACE_URL = "https://tonyassi-voice-clone.hf.space" # Update this!
30
 
 
 
31
  try:
32
- logger.info(f"Loading Gradio Client for {HF_SPACE_URL}...")
33
  client = Client(HF_SPACE_URL)
34
- logger.info("Client loaded successfully!")
35
- except ValueError as e:
36
- logger.error(f"Failed to load client: {e}")
37
- print(f"ERROR: Invalid URL. Visit https://huggingface.co/spaces/tonyassi/voice-clone, click 'Use via API', and copy the base URL.")
38
- exit(1)
39
  except Exception as e:
40
- logger.error(f"Unexpected error loading client: {e}")
41
  exit(1)
42
 
43
- # Store active tasks and quota information
44
  active_tasks = {}
45
- quota_info = {
46
- 'reset_time': None,
47
- 'retry_after': None
48
- }
49
 
50
- @app.route('/status')
 
51
  def status_check():
52
- """Health check endpoint for monitoring"""
53
- return {
54
- 'status': 'ok',
55
- 'quota_reset_time': quota_info['reset_time'],
56
- 'active_tasks': len(active_tasks)
57
- }
58
-
59
- @socketio.on('connect')
60
  def handle_connect():
61
- logger.info('Client connected: %s', request.sid)
62
- # Send current quota status to newly connected client
63
- if quota_info['reset_time'] and quota_info['reset_time'] > datetime.now():
64
- time_left = quota_info['reset_time'] - datetime.now()
65
- hours, remainder = divmod(time_left.total_seconds(), 3600)
66
- minutes, seconds = divmod(remainder, 60)
67
- emit('error', {'message': f'GPU quota exceeded. Try again in {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}'})
68
- else:
69
- emit('status', {'message': 'Connected to backend'})
70
-
71
- @socketio.on('disconnect')
72
  def handle_disconnect():
73
- logger.info('Client disconnected: %s', request.sid)
74
- # Clean up any active tasks for this client
75
- if request.sid in active_tasks:
76
- del active_tasks[request.sid]
77
 
78
- @socketio.on('generate_voice')
79
  def handle_generate_voice(data):
 
80
  try:
81
- sid = request.sid
82
- logger.info('Generate voice request from: %s', sid)
83
-
84
- # Check if we're in a quota timeout period
85
- if quota_info['reset_time'] and quota_info['reset_time'] > datetime.now():
86
- time_left = quota_info['reset_time'] - datetime.now()
87
- hours, remainder = divmod(time_left.total_seconds(), 3600)
88
- minutes, seconds = divmod(remainder, 60)
89
- emit('error', {'message': f'GPU quota exceeded. Try again in {int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}'})
90
  return
91
-
92
- text = data['text']
93
- audio_base64 = data['audio'] # Base64 from frontend
94
-
95
- if not text:
96
- raise ValueError("No text provided")
97
-
98
- # Store that this client has an active task
99
- active_tasks[sid] = {
100
- 'start_time': datetime.now(),
101
- 'status': 'processing'
102
- }
103
-
104
- # Send immediate acknowledgement
105
- emit('status', {'message': 'Processing your request...'})
106
-
107
- # Run the processing in a thread to avoid blocking
108
- thread = threading.Thread(target=process_voice, args=(sid, text, audio_base64))
109
- thread.daemon = True
110
- thread.start()
111
-
112
  except Exception as e:
113
- logger.error(f"Error in generate_voice handler: {e}")
114
- emit('error', {'message': f"Failed to process request: {str(e)}"})
115
- # Clean up task tracking
116
  if sid in active_tasks:
117
  del active_tasks[sid]
118
 
119
- def process_voice(sid, text, audio_base64, retry_count=0):
120
- """Process voice generation in a separate thread with retry logic"""
121
- max_retries = 2 # Maximum number of retry attempts
122
-
123
  try:
124
- # Decode base64 to temp file (skip prefix like 'data:audio/wav;base64,')
125
- if audio_base64.startswith('data:'):
126
- audio_base64 = audio_base64.split(',')[1]
127
-
128
- audio_data = base64.b64decode(audio_base64)
129
- temp_audio_path = f'/tmp/temp_reference_{sid}.wav' # Unique filename for each client
130
- with open(temp_audio_path, 'wb') as f:
131
- f.write(audio_data)
132
-
133
- logger.info("Calling HF Space for client: %s (attempt %d)", sid, retry_count + 1)
134
-
135
- # Send progress update
136
- with app.app_context():
137
- socketio.emit('status', {'message': f'Calling Hugging Face API... (Attempt {retry_count + 1})'}, room=sid)
138
-
139
- # Call the API with timeout handling
140
- try:
141
- # Call the API (api_name="/predict" matches Gradio's default for your Interface)
142
- result = client.predict(
143
- text,
144
- handle_file(temp_audio_path),
145
- api_name="/predict"
146
- )
147
- except Exception as api_error:
148
- # Check if it's a timeout error and we should retry
149
- if "timeout" in str(api_error).lower() and retry_count < max_retries:
150
- logger.warning("API timeout for client %s, retrying... (attempt %d)", sid, retry_count + 1)
151
- with app.app_context():
152
- socketio.emit('status', {'message': f'Timeout occurred, retrying... (Attempt {retry_count + 2})'}, room=sid)
153
- # Wait a bit before retrying
154
- time.sleep(5)
155
- # Retry the request
156
- process_voice(sid, text, audio_base64, retry_count + 1)
157
- return
158
- else:
159
- raise api_error # Re-raise if not a timeout or max retries exceeded
160
-
161
- # Send progress update
162
- with app.app_context():
163
- socketio.emit('status', {'message': 'Processing audio response...'}, room=sid)
164
-
165
- # Read and encode output to base64
166
- with open(result, 'rb') as f:
167
  output_audio = f.read()
168
- output_base64 = base64.b64encode(output_audio).decode('utf-8')
169
-
170
- # Cleanup
171
- try:
172
- os.remove(temp_audio_path)
173
- if os.path.exists(result):
174
- os.remove(result)
175
- except Exception as cleanup_error:
176
- logger.warning("Cleanup error for client %s: %s", sid, cleanup_error)
177
-
178
- logger.info("Generation complete for client: %s", sid)
179
-
180
- # Send results back to the specific client
181
- with app.app_context():
182
- socketio.emit('voice_generated', {'audio': f'data:audio/wav;base64,{output_base64}'}, room=sid)
183
- socketio.emit('status', {'message': 'Generation complete'}, room=sid)
184
-
185
  except Exception as e:
186
- logger.error(f"Error in process_voice for client {sid}: {e}")
187
-
188
- # Check if this is a GPU quota error
189
- error_msg = str(e)
190
- if "quota" in error_msg.lower():
191
- # Try to parse the time from the error message
192
- time_match = re.search(r'Try again in (\d+):(\d+):(\d+)', error_msg)
193
- if time_match:
194
- hours, minutes, seconds = map(int, time_match.groups())
195
- reset_time = datetime.now() + timedelta(hours=hours, minutes=minutes, seconds=seconds)
196
- quota_info['reset_time'] = reset_time
197
- quota_info['retry_after'] = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
198
-
199
- logger.warning("GPU quota exceeded. Resets at: %s", reset_time)
200
- error_msg = f"GPU quota exceeded. Try again in {hours:02d}:{minutes:02d}:{seconds:02d}"
201
-
202
- with app.app_context():
203
- socketio.emit('error', {'message': f"Generation failed: {error_msg}"}, room=sid)
204
  finally:
205
- # Clean up task tracking
 
 
206
  if sid in active_tasks:
207
  del active_tasks[sid]
208
 
 
209
  def cleanup_old_files():
210
- """Clean up temporary files older than 1 hour"""
211
- try:
212
  now = time.time()
213
- for filename in os.listdir('/tmp'):
214
- if filename.startswith('temp_reference_') and filename.endswith('.wav'):
215
- filepath = os.path.join('/tmp', filename)
216
- if os.path.isfile(filepath) and now - os.path.getctime(filepath) > 3600:
217
- os.remove(filepath)
218
- logger.info("Cleaned up old file: %s", filename)
219
- except Exception as e:
220
- logger.warning("Error during cleanup: %s", e)
221
-
222
- # Start a background thread for periodic cleanup
223
- def start_cleanup_thread():
224
- def cleanup_loop():
225
- while True:
226
- time.sleep(3600) # Clean up every hour
227
- cleanup_old_files()
228
-
229
- thread = threading.Thread(target=cleanup_loop)
230
- thread.daemon = True
231
- thread.start()
232
-
233
- if __name__ == '__main__':
234
- logger.info("Starting backend with improved timeout and quota handling...")
235
- start_cleanup_thread()
236
- socketio.run(app, host='0.0.0.0', port=5000, debug=True, allow_unsafe_werkzeug=True)
 
1
+ # spam_space_backend.py
2
+ # Install: pip install flask flask-socketio gradio_client
3
 
4
+ from flask import Flask, request
5
  from flask_socketio import SocketIO, emit
6
  from gradio_client import Client, handle_file
7
+ import os, base64, threading, time, logging
 
 
 
 
 
8
  from datetime import datetime, timedelta
9
 
10
+ # ----------------- Logging -----------------
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
  logger = logging.getLogger(__name__)
13
 
14
+ # ----------------- Flask + SocketIO -----------------
15
  app = Flask(__name__)
16
+ # Use 'threading' mode for maximum compatibility on Spaces
17
+ socketio = SocketIO(app, cors_allowed_origins="*", async_mode='threading')
 
 
 
 
 
 
 
 
18
 
19
+ # ----------------- HF Space -----------------
20
+ HF_SPACE_URL = "https://tonyassi-voice-clone.hf.space" # Replace with your Space API URL
21
  try:
 
22
  client = Client(HF_SPACE_URL)
23
+ logger.info("Gradio Client loaded successfully!")
 
 
 
 
24
  except Exception as e:
25
+ logger.error(f"Failed to load client: {e}")
26
  exit(1)
27
 
28
+ # ----------------- Task & Quota Tracking -----------------
29
  active_tasks = {}
30
+ quota_info = {"reset_time": None, "retry_after": None}
 
 
 
31
 
32
+ # ----------------- Routes -----------------
33
+ @app.route("/status")
34
  def status_check():
35
+ return {"status": "ok", "active_tasks": len(active_tasks), "quota_reset_time": quota_info["reset_time"]}
36
+
37
+ # ----------------- SocketIO Events -----------------
38
+ @socketio.on("connect")
 
 
 
 
39
  def handle_connect():
40
+ sid = request.sid
41
+ logger.info(f"Client connected: {sid}")
42
+ emit("status", {"message": "Connected to backend"})
43
+
44
+ @socketio.on("disconnect")
 
 
 
 
 
 
45
  def handle_disconnect():
46
+ sid = request.sid
47
+ logger.info(f"Client disconnected: {sid}")
48
+ if sid in active_tasks:
49
+ del active_tasks[sid]
50
 
51
+ @socketio.on("generate_voice")
52
  def handle_generate_voice(data):
53
+ sid = request.sid
54
  try:
55
+ text = data.get("text")
56
+ audio_base64 = data.get("audio")
57
+
58
+ if not text or not audio_base64:
59
+ emit("error", {"message": "Text or audio missing"})
 
 
 
 
60
  return
61
+
62
+ # Track active task
63
+ active_tasks[sid] = {"start_time": datetime.now(), "status": "processing"}
64
+ emit("status", {"message": "Processing request..."})
65
+
66
+ # Process in background thread
67
+ threading.Thread(target=process_voice, args=(sid, text, audio_base64), daemon=True).start()
68
+
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  except Exception as e:
70
+ logger.error(f"Error in generate_voice: {e}")
71
+ emit("error", {"message": f"Failed to process request: {str(e)}"})
 
72
  if sid in active_tasks:
73
  del active_tasks[sid]
74
 
75
+ # ----------------- Voice Processing -----------------
76
+ def process_voice(sid, text, audio_base64):
77
+ temp_audio_path = f"/tmp/temp_reference_{sid}.wav"
 
78
  try:
79
+ # Decode audio
80
+ if audio_base64.startswith("data:"):
81
+ audio_base64 = audio_base64.split(",")[1]
82
+ with open(temp_audio_path, "wb") as f:
83
+ f.write(base64.b64decode(audio_base64))
84
+
85
+ # Call HF Space API
86
+ socketio.emit("status", {"message": "Calling HF Space API..."}, room=sid)
87
+ result_path = client.predict(text, handle_file(temp_audio_path), api_name="/predict")
88
+
89
+ # Read result and send back
90
+ with open(result_path, "rb") as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  output_audio = f.read()
92
+ output_base64 = base64.b64encode(output_audio).decode("utf-8")
93
+
94
+ socketio.emit("voice_generated", {"audio": f"data:audio/wav;base64,{output_base64}"}, room=sid)
95
+ socketio.emit("status", {"message": "Generation complete"}, room=sid)
96
+
 
 
 
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
+ logger.error(f"Error in process_voice: {e}")
99
+ socketio.emit("error", {"message": f"Generation failed: {str(e)}"}, room=sid)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  finally:
101
+ # Cleanup
102
+ if os.path.exists(temp_audio_path):
103
+ os.remove(temp_audio_path)
104
  if sid in active_tasks:
105
  del active_tasks[sid]
106
 
107
+ # ----------------- Cleanup Thread -----------------
108
  def cleanup_old_files():
109
+ while True:
 
110
  now = time.time()
111
+ for f in os.listdir("/tmp"):
112
+ if f.startswith("temp_reference_") and f.endswith(".wav"):
113
+ path = os.path.join("/tmp", f)
114
+ if now - os.path.getctime(path) > 3600: # 1 hour
115
+ os.remove(path)
116
+ logger.info(f"Removed old file: {f}")
117
+ time.sleep(3600)
118
+
119
+ threading.Thread(target=cleanup_old_files, daemon=True).start()
120
+
121
+ # ----------------- Main -----------------
122
+ if __name__ == "__main__":
123
+ logger.info("Starting backend...")
124
+ socketio.run(app, host="0.0.0.0", port=5000, debug=True)