|
|
""" |
|
|
Error handling utilities for VidGen |
|
|
""" |
|
|
|
|
|
import functools
import time
from typing import Any, Callable, Iterable, Optional

import streamlit as st
|
|
|
|
|
def retry_on_error(
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    *,
    retryable_errors: Optional[Iterable[str]] = None,
):
    """Decorator factory that retries a function when it fails with a transient error.

    An exception counts as retryable when its lower-cased message contains any
    of the configured substrings (matching is case-insensitive). A
    non-retryable exception is re-raised immediately. A retryable exception
    raised on the final attempt is also re-raised. Between attempts a Streamlit
    warning is shown and the wrapper sleeps.

    Args:
        max_retries: Total number of attempts (including the first call).
            Values below 1 are treated as 1, so the wrapped function always
            runs at least once.
        delay: Seconds to wait after the first failed attempt.
        backoff: Multiplier for the wait; the wait after the failed attempt
            with 0-based index ``i`` is ``delay * backoff ** i``.
        retryable_errors: Message substrings that identify transient errors.
            Defaults to the built-in list below (stream abort, connection
            error, timeout, network error, temporary failure, rate limit).

    Returns:
        A decorator that wraps a callable with the retry logic.
    """
    if retryable_errors is None:
        retryable_errors = (
            'bodystreambuffer was aborted',
            'connection error',
            'timeout',
            'network error',
            'temporary failure',
            'rate limit',
        )
    # The exception message is lower-cased before matching, so the needles
    # must be lower case too, or a mixed-case pattern can never match.
    patterns = tuple(pattern.lower() for pattern in retryable_errors)
    # Guarantee at least one call; with zero attempts there would be no
    # exception to re-raise.
    attempts = max(1, max_retries)

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    error_msg = str(e).lower()
                    is_retryable = any(pattern in error_msg for pattern in patterns)
                    if not is_retryable or attempt == attempts - 1:
                        # Bare raise keeps the original traceback intact.
                        raise
                    wait_time = delay * (backoff ** attempt)
                    st.warning(f"⚠️ Attempt {attempt + 1} failed: {e}. Retrying in {wait_time:.1f}s...")
                    time.sleep(wait_time)
            # Not reached: the final attempt either returns or re-raises above.

        return wrapper

    return decorator
|
|
|
|
|
def handle_stream_error(func: Callable) -> Callable:
    """Decorator that turns interrupted-stream errors into a friendly UI message.

    If ``func`` raises an exception whose lower-cased message contains
    ``bodystreambuffer`` or ``aborted``, the wrapper does three things:

    - shows a Streamlit error,
    - shows a tip,
    - returns ``None`` instead of raising.

    Any other exception propagates unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            error_msg = str(e).lower()
            # Needles must be lower case because error_msg was lower-cased.
            if 'bodystreambuffer' not in error_msg and 'aborted' not in error_msg:
                # Bare raise keeps the original traceback intact.
                raise
            st.error("🔄 Connection interrupted. Please try again.")
            st.info("💡 Tip: Try refreshing the page if the issue persists.")
            return None

    return wrapper
|
|
|
|
|
class ProgressTracker:
    """Drive a Streamlit progress bar plus a status line for a batch of jobs.

    Usage:

    - Call ``start`` to create the widgets.
    - Call ``update`` after each job.
    - Call ``complete`` at the end.

    ``update`` and ``complete`` do nothing until ``start`` has been called.
    """

    def __init__(self, total_jobs: int):
        self.total_jobs = total_jobs
        # Initialized for callers; not advanced by any method in this class.
        self.current_job = 0
        self.progress_bar = None
        self.status_text = None

    def start(self):
        """Create the progress bar (at 0) and an empty placeholder for status text."""
        self.progress_bar = st.progress(0)
        self.status_text = st.empty()

    def _started(self) -> bool:
        """Return True once both widgets exist (i.e. ``start`` has run)."""
        return self.progress_bar is not None and self.status_text is not None

    def update(self, job_idx: int, status: str = "Processing..."):
        """Show progress after finishing the job at 0-based index ``job_idx``.

        The bar is set to ``(job_idx + 1) / total_jobs``, clamped to [0, 1].
        When ``total_jobs`` is not positive, the bar is shown as full.
        """
        if not self._started():
            return
        if self.total_jobs > 0:
            fraction = (job_idx + 1) / self.total_jobs
            # st.progress rejects floats outside [0.0, 1.0]; clamp so an
            # out-of-range index cannot crash the UI.
            fraction = min(max(fraction, 0.0), 1.0)
        else:
            # No jobs to run: avoid division by zero and show the bar as full.
            fraction = 1.0
        self.progress_bar.progress(fraction)
        self.status_text.text(f"{status} ({job_idx + 1}/{self.total_jobs})")

    def complete(self):
        """Fill the bar and show the completion message."""
        if not self._started():
            return
        self.progress_bar.progress(1.0)
        self.status_text.text("✅ All jobs completed!")