File size: 4,513 Bytes
3e25ded
 
 
 
 
 
 
1c0aad9
 
3e25ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a1c55b
 
 
 
 
3e25ded
 
1c0aad9
905e1ea
 
 
 
 
 
 
 
bc67f56
 
 
3e25ded
 
 
 
 
 
 
 
 
 
 
 
 
1c0aad9
 
 
 
 
 
3e25ded
 
1c0aad9
3e25ded
1c0aad9
3e25ded
 
 
 
 
1c0aad9
 
 
 
 
 
 
 
07c2818
 
 
 
 
1c0aad9
3e25ded
 
 
 
 
dd25aa6
905e1ea
3e25ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import shutil
import threading
import subprocess

from blossomtune_gradio.logs import log
from blossomtune_gradio import config as cfg
from blossomtune_gradio import util
from blossomtune_gradio.database import SessionLocal, Config


# Registry of the background subprocess handles managed by this module,
# keyed by role. A value of None means no process is tracked for that role.
process_store = dict.fromkeys(("superlink", "runner"))


def run_process(command, process_key):
    """Run ``command`` as a child process and forward its output to the log.

    The call blocks until the child's output stream ends and the child has
    exited, which is why the callers in this module run it on a daemon
    thread. While the child is tracked, its ``Popen`` handle is published in
    ``process_store[process_key]``.

    Args:
        command: Argument list passed to ``subprocess.Popen`` (no shell).
        process_key: Key in ``process_store`` under which the handle is
            stored; its title-cased form is used as the log prefix.

    Any exception raised while spawning or reading from the child is logged
    as a critical error instead of being propagated.
    """
    tag = f"[{process_key.title()}]"
    # Stringify each argument so a bad entry (e.g. ``None`` from a failed
    # executable lookup) cannot make the log line itself raise before the
    # ``try`` block gets a chance to report the real failure.
    log(f"{tag} Starting: {' '.join(map(str, command))}")
    process = None
    try:
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr with stdout
            text=True,
            bufsize=1,  # line-buffered so output is logged as it arrives
        )
        process_store[process_key] = process
        # Close the pipe deterministically once the stream is exhausted
        # (or reading fails) instead of waiting for garbage collection.
        with process.stdout as stream:
            for line in stream:
                log(f"{tag} {line.strip()}")
        process.wait()
    except Exception as e:
        log(f"{tag} CRITICAL ERROR: {e}")
    finally:
        log(f"{tag} Process finished.")
        # Only clear the slot if it still refers to the process started
        # here; otherwise a handle registered under the same key in the
        # meantime would be silently dropped.
        if process_store.get(process_key) is process:
            process_store[process_key] = None


def start_superlink():
    """Start the internal ``flower-superlink`` process on a background thread.

    Returns:
        A ``(success, message)`` tuple. ``success`` is ``False`` when the
        application is configured for an external Superlink, when a
        Superlink process is already running, or when the
        ``flower-superlink`` executable cannot be located.
    """
    # Do not start an internal process if in external mode.
    if cfg.SUPERLINK_MODE == "external":
        # NOTE(review): every other call site in this module invokes
        # ``log(...)`` directly; confirm that ``log`` also exposes
        # ``.warning`` before relying on this branch.
        log.warning("start_superlink called while in external mode. Operation aborted.")
        return False, "Application is in external Superlink mode."

    current = process_store["superlink"]
    if current and current.poll() is None:
        return False, "Superlink process is already running."

    # ``shutil.which`` returns None when the executable is not on PATH;
    # fail here rather than handing a None argument to the worker thread
    # and reporting a start that never happened.
    executable = shutil.which("flower-superlink")
    if executable is None:
        return False, "Unable to find the 'flower-superlink' executable on PATH."

    command = [
        executable,
        "--ssl-ca-certfile",
        cfg.BLOSSOMTUNE_TLS_CA_CERTFILE,
        "--ssl-certfile",
        cfg.BLOSSOMTUNE_TLS_CERTFILE,
        "--ssl-keyfile",
        cfg.BLOSSOMTUNE_TLS_KEYFILE,
        "--auth-list-public-keys",
        cfg.AUTH_KEYS_CSV_PATH,
    ]
    threading.Thread(
        target=run_process, args=(command, "superlink"), daemon=True
    ).start()
    return True, "Superlink process started."


def start_runner(
    runner_app: str,
    run_id: str,
    num_partitions: str,
):
    """Validate the inputs and launch ``flwr run`` on a background thread.

    Args:
        runner_app: Dotted name of the runner app; dots are replaced with
            the OS path separator to obtain the app directory.
        run_id: Identifier of the run. It must be non-empty but is not
            passed to the ``flwr`` command built here.
        num_partitions: Total number of partitions as a decimal string; it
            is persisted under the ``num_partitions`` config key.

    Returns:
        A ``(success, message)`` tuple. Nothing is written to the database
        and no process is started unless every check passes.
    """
    active_runner = process_store["runner"]
    if active_runner and active_runner.poll() is None:
        return False, "A Runner process is already running."

    # Check if the Superlink is running, respecting the configured mode
    if cfg.SUPERLINK_MODE == "external":
        if not util.is_port_open(cfg.SUPERLINK_HOST, cfg.SUPERLINK_PORT):
            return False, "External Superlink is not running or unreachable."
    else:
        superlink = process_store["superlink"]
        if not (superlink and superlink.poll() is None):
            return (
                False,
                "Internal Superlink is not running. Please start it before starting the runner.",
            )

    if not all([runner_app, run_id, num_partitions]):
        return False, "Please provide a Runner App, Run ID, and Total Partitions."

    # ``isdecimal`` only accepts characters that ``int`` can parse, unlike
    # ``isdigit`` which also matches e.g. superscript digits and would let
    # ``int`` raise ValueError.
    partitions_text = str(num_partitions).strip()
    if not partitions_text.isdecimal() or int(partitions_text) <= 0:
        return False, "Total Partitions must be a positive integer."
    # Store a canonical representation (no padding or non-ASCII digits).
    partitions_value = str(int(partitions_text))

    runner_app_path = runner_app.replace(".", os.path.sep)
    if not os.path.exists(runner_app_path):
        return False, f"Unable to find app path '{runner_app_path}'."

    # ``shutil.which`` returns None when the executable is not on PATH.
    flwr_executable = shutil.which("flwr")
    if flwr_executable is None:
        return False, "Unable to find the 'flwr' executable on PATH."

    # Update the number of partitions in the database using SQLAlchemy.
    # This happens only after all validation has passed so that a rejected
    # request does not change the persisted configuration.
    with SessionLocal() as db:
        config_entry = db.query(Config).filter(Config.key == "num_partitions").first()
        if config_entry:
            config_entry.value = partitions_value
        else:
            db.add(Config(key="num_partitions", value=partitions_value))
        db.commit()

    # Construct the command for a TLS-enabled runner
    command = [
        flwr_executable,
        "run",
        runner_app_path,
        "local-deployment",
        "--federation-config",
        f'address="{cfg.SUPERLINK_HOST}:{cfg.SUPERLINK_CONTROL_API_PORT}" root-certificates="{cfg.BLOSSOMTUNE_TLS_CA_CERTFILE}"',
        "--stream",
    ]
    threading.Thread(target=run_process, args=(command, "runner"), daemon=True).start()
    return True, "Federation Run is starting...."


def stop_process(
    process_key: str,
    *,
    timeout: float = 10.0,
):
    """Stop the background process tracked under ``process_key``.

    The process is first asked to terminate. If it has not exited after
    ``timeout`` seconds it is killed, so the caller cannot block forever on
    a child that ignores the termination request.

    Args:
        process_key: Key in ``process_store`` identifying the process.
        timeout: Seconds to wait after ``terminate()`` before escalating to
            ``kill()``.
    """
    tag = f"[{process_key.title()}]"
    process = process_store.get(process_key)
    if not process or process.poll() is not None:
        log(
            f"{tag} Stop command received, but no process was running."
        )
        return

    process.terminate()
    try:
        process.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        log(f"{tag} Process did not exit within {timeout}s; killing it.")
        process.kill()
        process.wait()
    log(f"{tag} Process stopped by user.")
    # Clear the slot only if it still refers to the process stopped here.
    if process_store.get(process_key) is process:
        process_store[process_key] = None