File size: 5,258 Bytes
479ab0f
 
 
 
 
 
 
b3f4e74
479ab0f
 
86040e9
 
 
 
479ab0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9475b8a
479ab0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3f4e74
 
 
 
7651b14
 
 
 
 
 
 
 
 
 
479ab0f
7651b14
479ab0f
7651b14
 
 
 
479ab0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9475b8a
479ab0f
 
 
b3f4e74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
from gradio.oauth import OAuthToken
from gradio_workflowcanvas import WorkflowCanvas
from gradio_client import Client, handle_file
from typing import Optional
import json
import os
import tempfile


def get_token(data=None, token: Optional[OAuthToken] = None) -> str:
    """Return the raw HF access-token string, or "" when no OAuth session exists.

    Exposed to the frontend as a WorkflowCanvas server function; ``data`` is
    the (unused) payload the canvas sends, ``token`` is injected by Gradio's
    OAuth integration when the user is logged in.
    """
    return token.token if token is not None else ""


def classify_error(e: Exception) -> dict:
    """Classify an exception into an error type plus a user-facing suggestion.

    Inspects the optional ``title``/``message`` attributes that
    ``gradio_client`` errors carry (falling back to ``str(e)``), lowercases
    the combined text, and matches known failure signatures in priority
    order (GPU > quota > sleeping > not-found > build > connection).

    Returns:
        dict with ``error_type`` and ``suggestion`` keys. Unrecognized
        errors get ``error_type == "unknown"`` and an empty suggestion.
    """
    title = getattr(e, 'title', None) or ""
    message = getattr(e, 'message', None) or str(e)
    full = f"{title} {message}".lower()

    # Parentheses make the intended precedence explicit (the original relied
    # on `and` binding tighter than `or`): match either "zerogpu", or BOTH
    # "gpu" and "worker" — a bare "gpu" mention alone is not enough.
    if "zerogpu" in full or ("gpu" in full and "worker" in full):
        return {"error_type": "gpu", "suggestion": "GPU unavailable — try again or log in with your HF account"}
    if "quota" in full or "rate limit" in full:
        return {"error_type": "quota", "suggestion": "GPU quota exceeded — log in with your HF account for more compute"}
    if "sleeping" in full or "paused" in full:
        return {"error_type": "sleeping", "suggestion": "Space is sleeping or paused — try again in a minute"}
    if "not found" in full or "404" in full or "repository not found" in full:
        return {"error_type": "not_found", "suggestion": "Space not found — it may have been deleted or renamed"}
    if "build_error" in full or "build error" in full:
        return {"error_type": "build_error", "suggestion": "Space has a build error — contact the Space owner"}
    if "timed out" in full or "timeout" in full or "connection" in full:
        return {"error_type": "connection", "suggestion": "Could not connect to the Space — it may be down"}

    return {"error_type": "unknown", "suggestion": ""}


def call_space(data, token: Optional[OAuthToken] = None) -> str:
    """Call an arbitrary HF Space endpoint on behalf of the workflow canvas.

    ``data`` is a positional list from the frontend:
    [space_id, endpoint, args_json, manual_token] — only the first element
    is required. Returns a JSON string: either a list of processed results
    or an error object with "error"/"error_type"/"suggestion" keys
    (see classify_error). Never raises — all exceptions are caught and
    serialized so the frontend can render them.
    """
    try:
        space_id = data[0]
        endpoint = data[1] if len(data) > 1 else None
        args_json = data[2] if len(data) > 2 else "[]"
        manual_token = data[3] if len(data) > 3 else None

        # Use manual token if provided (local dev), otherwise OAuth token (HF Spaces)
        hf_token = manual_token or (token.token if token else None)
        client = Client(space_id, token=hf_token)
        args = json.loads(args_json)

        # "/predict" doubles as the frontend's "no preference" default: when
        # the requested endpoint is missing or "/predict", ask the Space what
        # it actually exposes and prefer its first named endpoint.
        if not endpoint or endpoint == "/predict":
            api_info = client.view_api(return_format="dict")
            named = list(api_info.get("named_endpoints", {}).keys())
            if endpoint and endpoint in named:
                pass
            elif named:
                endpoint = named[0]
            else:
                endpoint = "/predict"

        # Convert file-like args ({"url": ...} / {"path": ...}) into
        # gradio_client file handles; everything else passes through as-is.
        processed = []
        for arg in args:
            if arg is None:
                processed.append(None)
            elif isinstance(arg, dict) and ("url" in arg or "path" in arg):
                url = arg.get("url") or arg.get("path", "")
                if url:
                    processed.append(handle_file(url))
                else:
                    processed.append(None)
            else:
                processed.append(arg)

        # Strip trailing None args so the Space uses its own defaults.
        # view_api(return_format="dict") returns default=None even when defaults
        # exist, so we'd otherwise pass null for optional args and get
        # "No value provided for required argument".
        while processed and processed[-1] is None:
            processed.pop()

        result = client.predict(*processed, api_name=endpoint)

        # Normalize to a list so single-output endpoints serialize uniformly.
        if not isinstance(result, (list, tuple)):
            result = [result]
        else:
            result = list(result)

        def make_file_url(filepath):
            """Convert a local file path to a browser-accessible /gradio_api/file= URL."""
            return {"path": filepath, "url": f"/gradio_api/file={filepath}", "is_file": True}

        def process_item(item):
            """Convert result item to a browser-accessible format."""
            if isinstance(item, dict):
                # Prefer local path (downloaded by gradio_client) over remote URL
                path = item.get("path") or item.get("value")
                if isinstance(path, str) and os.path.exists(path):
                    return make_file_url(path)
                if "url" in item:
                    return item
                return item
            elif isinstance(item, str) and os.path.exists(item):
                return make_file_url(item)
            elif isinstance(item, (list, tuple)):
                # Recurse into nested results (e.g. galleries returning lists).
                return [process_item(s) for s in item]
            return item

        output = [process_item(item) for item in result]

        return json.dumps(output)
    except Exception as e:
        # Broad catch is deliberate: this is the server boundary and errors
        # must reach the frontend as JSON, not as a 500.
        title = getattr(e, 'title', None)
        message = getattr(e, 'message', None) or str(e)
        classified = classify_error(e)
        error_info = {
            "error": message,
            **classified,
        }
        if title:
            error_info["title"] = title
        return json.dumps(error_info)


# App wiring: the hidden LoginButton enables HF OAuth so the `token`
# argument is injected into the server functions when running on Spaces;
# the canvas calls get_token / call_space from the frontend.
with gr.Blocks() as demo:
    gr.LoginButton(visible=False)
    canvas = WorkflowCanvas(server_functions=[get_token, call_space])

if __name__ == "__main__":
    # CSS hides Gradio's toast popups (errors are rendered by the canvas
    # instead); allowed_paths exposes the temp dir so downloaded result
    # files can be served via /gradio_api/file= URLs.
    demo.launch(css=".toast-wrap { display: none !important; }", allowed_paths=[tempfile.gettempdir()])