File size: 2,825 Bytes
20651a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import requests

# Configuration for the MCP server, read once from the environment at import time.
# MCP_SERVER_URL: base URL of the server. A trailing "/" is stripped so that
#   f"{MCP_SERVER_URL}/api/{endpoint}" never produces a double slash.
# HF_TOKEN: optional Hugging Face token, sent as "Authorization: Bearer <token>".
#   Surrounding whitespace (e.g. a trailing newline from a secrets file) is
#   stripped because requests rejects header values containing newlines or
#   leading whitespace.
MCP_SERVER_URL = os.getenv("MCP_SERVER_URL", "http://localhost:8000").rstrip("/")
HF_TOKEN = os.getenv("HF_TOKEN", "").strip()


def call_api_endpoint(endpoint: str, data: dict = None, method: str = "POST", timeout: float = 30.0):
    """Call an API endpoint on the MCP server.

    Args:
        endpoint: Path segment appended to ``{MCP_SERVER_URL}/api/``.
        data: JSON body for POST requests. An empty object is sent when None.
            Ignored for GET requests.
        method: HTTP method. "GET" (case-insensitive) issues a GET request;
            any other value issues a POST.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        The decoded JSON response on success. On any failure (connection error,
        timeout, HTTP error status, non-JSON body) the error is printed and
        ``{"error": <message>}`` is returned instead of raising.
    """
    url = f"{MCP_SERVER_URL}/api/{endpoint}"

    # Include HF token for authentication to private Spaces
    headers = {}
    if HF_TOKEN:
        headers["Authorization"] = f"Bearer {HF_TOKEN}"

    try:
        if method.upper() == "GET":
            response = requests.get(url, headers=headers, timeout=timeout)
        else:
            response = requests.post(url, json=data or {}, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Deliberately best-effort: callers inspect the returned dict rather
        # than handling exceptions.
        print(f"Error calling API {endpoint}: {e}")
        return {"error": str(e)}


def get_catalog():
    """Fetch the dataset catalog from the MCP server.

    Returns:
        The catalog list when the server replies with a JSON array. Otherwise
        an empty list; the reason (an API error or an unexpected response
        shape) is printed rather than raised.
    """
    result = call_api_endpoint("catalog", method="GET")
    if isinstance(result, list):
        return result
    # The server's JSON may be a dict, string, number or null. Only probe
    # dicts for "error": `"error" in None` or `in 5` would raise TypeError, and
    # `in "..."` would be a substring test.
    if isinstance(result, dict) and "error" in result:
        print(f"Error fetching catalog: {result['error']}")
    else:
        print(f"Unexpected catalog response type: {type(result).__name__}")
    return []


def get_user_subscriptions(hf_user: str, hf_token: str = None):
    """Fetch subscriptions for a specific user.

    Args:
        hf_user: Hugging Face username. An empty value returns [] without a request.
        hf_token: Hugging Face token, forwarded in the request body. Required;
            when missing, a warning is printed and [] is returned without a request.

    Returns:
        The subscription list when the server replies with a JSON array.
        Otherwise an empty list; the reason is printed rather than raised.
    """
    if not hf_user:
        return []
    if not hf_token:
        print("Warning: hf_token required for user_subscriptions")
        return []
    result = call_api_endpoint("user_subscriptions", {
        "hf_user": hf_user,
        "hf_token": hf_token
    })
    if isinstance(result, list):
        return result
    # Only dicts can carry an "error" key. Guarding avoids TypeError on
    # null/number payloads and substring matches on string payloads.
    if isinstance(result, dict) and "error" in result:
        print(f"Error fetching subscriptions: {result['error']}")
    else:
        print(f"Unexpected subscriptions response type: {type(result).__name__}")
    return []


def subscribe_free(dataset_id: str, hf_user: str, hf_token: str = None):
    """Ask the MCP server to subscribe ``hf_user`` to a free dataset.

    A missing or empty ``hf_token`` is sent as an empty string. Returns the
    result of ``call_api_endpoint`` unchanged: either the decoded JSON reply
    or an ``{"error": ...}`` dict.
    """
    payload = {"dataset_id": dataset_id}
    payload["hf_token"] = hf_token if hf_token else ""
    payload["hf_user"] = hf_user
    return call_api_endpoint("subscribe_free", payload)


def create_checkout_session(dataset_id: str, hf_user: str, hf_token: str = None):
    """Request a Stripe checkout session for a paid dataset from the MCP server.

    A missing or empty ``hf_token`` is sent as an empty string. The return
    value is whatever ``call_api_endpoint`` produced: the decoded JSON reply,
    or an ``{"error": ...}`` dict on failure.
    """
    token = "" if not hf_token else hf_token
    body = dict(dataset_id=dataset_id, hf_token=token, hf_user=hf_user)
    return call_api_endpoint("create_checkout_session", body)


# Legacy wrapper for backwards compatibility
def call_mcp_tool(tool_name: str, arguments: dict):
    """Legacy dispatcher kept for backwards compatibility.

    Prefer the dedicated functions in this module. "subscribe_free" and
    "create_checkout_session" forward ``arguments`` as-is to the endpoint of
    the same name. "get_dataset_catalog" ignores ``arguments`` and returns the
    catalog. Any other name prints a notice and returns None.
    """
    # These tool names map one-to-one onto API endpoints of the same name.
    passthrough_endpoints = ("subscribe_free", "create_checkout_session")
    if tool_name in passthrough_endpoints:
        return call_api_endpoint(tool_name, arguments)
    if tool_name == "get_dataset_catalog":
        return get_catalog()

    print(f"Tool {tool_name} not supported via API.")
    return None