File size: 3,927 Bytes
ec0af28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import requests
import config

class SmartRouter:
    """Pick a backend lane for each of two models, favouring lanes that
    are already recorded as holding the requested model."""

    def __init__(self):
        # Best-guess record of the model each lane currently has loaded
        # (None until the first routing decision has been made).
        #   "primary"   -> URL: ...-generate-primary.modal.run
        #   "secondary" -> URL: ...-generate-secondary.modal.run
        self.lane_state = dict.fromkeys(("primary", "secondary"))

    def get_routing_plan(self, model_left_id, model_right_id):
        """
        Choose the lane assignment with the fewest cache misses.

        Two assignments are compared: "straight" (left -> primary,
        right -> secondary) and "swapped" (left -> secondary,
        right -> primary). Each lane whose recorded model differs from the
        model it would receive counts as one miss. The swapped assignment
        is used only when it has strictly fewer misses; ties stay straight.

        The chosen assignment is written back to ``self.lane_state`` before
        returning, so it becomes the baseline for the next call.

        Returns: (lane_for_left_model, lane_for_right_model)
        """
        loaded = self.lane_state

        def count_misses(assignment):
            # 0 = lane already holds that model (hit), 1 = it does not (miss).
            return sum(0 if loaded[lane] == model else 1
                       for lane, model in assignment)

        straight = (("primary", model_left_id), ("secondary", model_right_id))
        swapped = (("secondary", model_left_id), ("primary", model_right_id))

        straight_misses = count_misses(straight)
        swapped_misses = count_misses(swapped)

        if swapped_misses < straight_misses:
            print(f"🔀 Smart Router: Swapping lanes to optimize cache!")
            chosen = swapped
        else:
            print(f"⬇️ Smart Router: keeping straight lanes.")
            chosen = straight

        # Remember what each lane is about to load for the next decision.
        for lane, model in chosen:
            loaded[lane] = model

        (left_lane, _), (right_lane, _) = chosen
        return left_lane, right_lane

# Module-level SmartRouter instance. Its lane_state is kept on this object,
# so each get_routing_plan() call builds on the assignment recorded by the
# previous call for as long as this module stays loaded.
router = SmartRouter()

# --- Streaming client for the Modal generation API ---
def call_modal_api(model_repo_id, prompt, lane):
    """
    Stream generated text for ``prompt`` from the Modal API on one lane.

    This is a GENERATOR: it yields text chunks as they arrive. Problems
    (missing model, missing token, HTTP or network failures) are reported
    by yielding a single human-readable message instead of raising.

    Args:
        model_repo_id: Model identifier sent to the API as ``model_id``.
            A falsy value yields a "please select" message and stops.
        prompt: Prompt text sent to the API as ``prompt``.
        lane: ``"primary"`` targets the primary endpoint; any other value
            targets the secondary endpoint.

    Yields:
        str: Decoded response chunks, or one error/notice message.
    """
    if not model_repo_id:
        yield "Please select a model from the dropdown."
        return  # Stop the generator

    if not config.MY_AUTH_TOKEN:
        yield "Error: `ARENA_AUTH_TOKEN` is not set on the Gradio server."
        return

    # Construct the URL based on the lane
    lane_suffix = "primary" if lane == "primary" else "secondary"
    endpoint = f"{config.MODAL_BASE_URL}-generate-{lane_suffix}.modal.run"

    print(f"🚀 Streaming from {model_repo_id} on [{lane.upper()}]...")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {config.MY_AUTH_TOKEN}",
    }
    payload = {"model_id": model_repo_id, "prompt": prompt}

    response = None
    try:
        # stream=True keeps the body unread so chunks can be forwarded
        # as soon as the server sends them.
        response = requests.post(
            endpoint,
            json=payload,
            timeout=300,
            headers=headers,
            stream=True,
        )
        response.raise_for_status()

        # With no known encoding, iter_content(decode_unicode=True) hands
        # back raw bytes; fall back to UTF-8 so callers always get str.
        if response.encoding is None:
            response.encoding = "utf-8"

        # Yield tokens as they arrive
        for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
            if chunk:
                yield chunk

    except requests.exceptions.RequestException as e:
        # NOTE: compare against None explicitly. A Response is falsy for
        # every 4xx/5xx status, so a plain truthiness test would skip the
        # branches below for exactly the errors they are meant to handle.
        failed = e.response
        if failed is None:
            yield f"API Error: {e}"
        elif failed.status_code == 401:
            yield "Error: Authentication failed. The token is invalid."
        else:
            # Try to get error detail from the streaming API. This is
            # best-effort: the body may be unreadable, not JSON, or not a
            # JSON object, in which case the generic message is used.
            try:
                error_detail = failed.json().get("detail", str(e))
            except Exception:
                message = f"API Error: {e}"
            else:
                message = f"API Error: {failed.status_code} - {error_detail}"
            # Yield outside the try so closing the generator here is not
            # mistaken for a JSON-parsing failure.
            yield message
    except Exception as e:
        yield f"An unexpected error occurred: {e}"
    finally:
        # Runs after the handlers above (which may still read the error
        # body) and also when the consumer abandons the generator early,
        # so the streamed connection is always released.
        if response is not None:
            response.close()