File size: 3,861 Bytes
0f83929
30ef6ad
b212dfe
0f83929
 
 
 
 
7ab7557
b212dfe
 
7ab7557
0f83929
2eced7f
 
 
 
 
 
 
 
 
 
 
 
 
 
0f83929
7ab7557
 
 
b212dfe
0f83929
 
 
 
 
 
 
2eced7f
0f83929
2eced7f
30ef6ad
 
2eced7f
 
8278018
2eced7f
 
 
 
 
 
 
30ef6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98d7910
 
b212dfe
4c48d2c
a3bb8ec
f0c397d
 
4c48d2c
98d7910
30ef6ad
 
 
ba20cfe
 
 
 
 
 
 
 
 
 
 
47d8bee
0f83929
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import random
import threading
import psutil
import fastapi
import gradio as gr
import uvicorn

from viser_proxy_manager import ViserProxyManager
from vis_st4rtrack import visualize_st4rtrack, load_trajectory_data, log_memory_usage

# Module-level cache for the trajectory data. main() fills it once at startup
# via load_trajectory_data(), and it is then handed to every session's
# visualization thread as "preloaded_data" instead of being reloaded.
global_data_cache = None

def check_ram_usage(threshold_percent=90):
    """Report whether system RAM usage is still under a given limit.

    The current usage percentage, as reported by
    ``psutil.virtual_memory().percent``, is printed to stdout on every call.

    Args:
        threshold_percent: Usage percentage that must not be reached for the
            check to pass.

    Returns:
        bool: True while usage is strictly below ``threshold_percent``,
        False otherwise.
    """
    memory_stats = psutil.virtual_memory()
    used_percent = memory_stats.percent
    print(f"Current RAM usage: {used_percent}%")

    is_below_limit = used_percent < threshold_percent
    return is_below_limit


def main() -> None:
    """Launch the combined Gradio / Viser web application.

    Trajectory data is loaded a single time and stored in the module-level
    ``global_data_cache`` so every browser session reuses it. A FastAPI app is
    created and handed to a ``ViserProxyManager``, and a Gradio page is
    mounted at ``/``. The page starts a per-session Viser server in its load
    handler and stops it in its unload handler. Finally the app is served by
    uvicorn on ``0.0.0.0:7860``.
    """
    # Populate the shared cache before any session can ask for it.
    global global_data_cache
    global_data_cache = load_trajectory_data(use_float16=True, max_frames=32)

    api = fastapi.FastAPI()
    proxy_manager = ViserProxyManager(api)

    with gr.Blocks(title="Viser Viewer") as page:
        # Placeholder that later holds either the viewer iframe or an
        # overload notice (the iframe itself is rendered with frameborder="0").
        viewer_html = gr.HTML("")
        # Status message component, filled in by the load handler.
        status_md = gr.Markdown("")

        @page.load(outputs=[viewer_html, status_md])
        def start_server(request: gr.Request):
            # Load handler: returns (html, markdown) for the two components
            # above. The session hash identifies this session's Viser server.
            session_id = request.session_hash
            assert session_id is not None

            # Refuse to start another viewer when memory is exhausted.
            # NOTE(review): with a threshold of 100 this guard only trips
            # when usage is reported as 100% or more - confirm that is
            # intended rather than the helper's default of 90.
            ram_ok = check_ram_usage(threshold_percent=100)
            if not ram_ok:
                overload_notice = """
                <div style="text-align: center; padding: 20px; background-color: #ffeeee; border-radius: 5px;">
                    <h2>⚠️ Server is currently under high load</h2>
                    <p>Please try again later when resources are available.</p>
                </div>
                """
                overload_status = "**System Status:** High memory usage detected. Visualization not loaded to prevent server overload."
                return overload_notice, overload_status

            proxy_manager.start_server(session_id)

            # The iframe URL is built from the Host header of this request.
            # A missing header raises KeyError; there is no fallback.
            host = request.headers["host"]

            # Use https when a fronting proxy (e.g. HuggingFace Spaces)
            # reports it via x-forwarded-proto; otherwise plain http.
            if request.headers.get("x-forwarded-proto") == "https":
                scheme = "https"
            else:
                scheme = "http"

            # Run the visualization on a daemon thread so this handler can
            # return the iframe markup without waiting for it.
            session_server = proxy_manager.get_server(session_id)
            render_options = {
                "server": session_server,
                "use_float16": True,
                # Reuse the data loaded once at startup.
                "preloaded_data": global_data_cache,
                "color_code": "jet",
                "blue_rgb": (0.0, 0.149, 0.463),  # #002676
                # NOTE(review): these floats correspond to #C4820E; the
                # previous comment said #FDB515, which would be roughly
                # (0.992, 0.710, 0.082) - confirm which colour is intended.
                "red_rgb": (0.769, 0.510, 0.055),
                "blend_ratio": 0.7,
            }
            worker = threading.Thread(
                target=visualize_st4rtrack,
                kwargs=render_options,
                daemon=True,
            )
            worker.start()

            # Spelled out line by line so the exact whitespace of the markup
            # (including trailing spaces inside the tag) is explicit.
            iframe_markup = (
                "\n"
                "            <iframe \n"
                f'                src="{scheme}://{host}/viser/{session_id}/" \n'
                '                width="100%" \n'
                '                height="500px" \n'
                '                frameborder="0" \n'
                '                style="display: block;"\n'
                '                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"\n'
                '                loading="lazy"\n'
                "            ></iframe>\n"
                "            "
            )
            return iframe_markup, "**System Status:** Visualization loaded successfully."

        @page.unload
        def stop(request: gr.Request):
            # Unload handler: shut down the Viser server tied to this session.
            session_id = request.session_hash
            assert session_id is not None
            proxy_manager.stop_server(session_id)

    # Serve the Gradio page from the root of the FastAPI app.
    gr.mount_gradio_app(api, page, "/")
    uvicorn.run(api, host="0.0.0.0", port=7860)


# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()