Ark-kun commited on
Commit
73520b7
·
1 Parent(s): 3d0c3c0

feat: Support multi-tenant mode

Browse files
huggingface_overlay/start_HuggingFace.py CHANGED
@@ -2,8 +2,23 @@ import os
2
 
3
  __all__ = ["app"]
4
 
5
- print("Starting single-tenant mode")
6
- from start_HuggingFace_single_tenant import app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  if __name__ == "__main__":
9
  import uvicorn
 
2
 
3
  __all__ = ["app"]
4
 
5
+ _MULTI_TENANT_SPACE_IDS = [
6
+ "TangleML/tangle",
7
+ "TangleML/tangle_multi_tenant",
8
+ "Ark-kun/tangle_multi_tenant",
9
+ ]
10
+
11
+ _is_multi_tenant = (
12
+ os.environ.get("MULTI_TENANT", "false").lower() == "true"
13
+ or os.environ.get("SPACE_ID") in _MULTI_TENANT_SPACE_IDS
14
+ )
15
+
16
+ if _is_multi_tenant:
17
+ print("Starting multi-tenant mode")
18
+ from start_HuggingFace_multi_tenant import app
19
+ else:
20
+ print("Starting single-tenant mode")
21
+ from start_HuggingFace_single_tenant import app
22
 
23
  if __name__ == "__main__":
24
  import uvicorn
huggingface_overlay/start_HuggingFace_multi_tenant.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from collections import abc
3
+ import dataclasses
4
+ import json
5
+ import logging
6
+ import os
7
+ import pathlib
8
+ import threading
9
+ import traceback
10
+ import typing
11
+ from typing import Any
12
+
13
+
14
+ import fastapi
15
+ import fastapi.routing
16
+ from fastapi import staticfiles
17
+ import huggingface_hub
18
+ import huggingface_hub.errors
19
+ import sqlalchemy
20
+ from sqlalchemy import orm
21
+
22
+
23
+ from cloud_pipelines.orchestration.storage_providers import (
24
+ interfaces as storage_interfaces,
25
+ )
26
+
27
+ # from cloud_pipelines_backend import api_router_multi_tenant
28
+ # from cloud_pipelines_backend import api_router_multi_tenant as api_router
29
+ from cloud_pipelines_backend import api_router
30
+ from cloud_pipelines_backend import database_ops
31
+ from cloud_pipelines_backend import orchestrator_sql
32
+ from cloud_pipelines_backend.launchers import huggingface_launchers
33
+ from cloud_pipelines_backend.launchers import local_docker_launchers
34
+ from cloud_pipelines_backend.launchers import interfaces as launcher_interfaces
35
+
36
+ # region: Logging configuration
37
+ import logging.config
38
+
39
+ LOGGING_CONFIG = {
40
+ "version": 1,
41
+ "disable_existing_loggers": True,
42
+ "formatters": {
43
+ "standard": {"format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s"},
44
+ },
45
+ "handlers": {
46
+ "default": {
47
+ # "level": "INFO",
48
+ "level": "DEBUG",
49
+ "formatter": "standard",
50
+ "class": "logging.StreamHandler",
51
+ "stream": "ext://sys.stderr",
52
+ },
53
+ },
54
+ "loggers": {
55
+ # root logger
56
+ "": {
57
+ "level": "INFO",
58
+ "handlers": ["default"],
59
+ "propagate": False,
60
+ },
61
+ __name__: {
62
+ "level": "DEBUG",
63
+ "handlers": ["default"],
64
+ "propagate": False,
65
+ },
66
+ "cloud_pipelines_backend.orchestrator_sql": {
67
+ "level": "DEBUG",
68
+ "handlers": ["default"],
69
+ "propagate": False,
70
+ },
71
+ "cloud_pipelines_backend.launchers.huggingface_launchers": {
72
+ "level": "DEBUG",
73
+ "handlers": ["default"],
74
+ "propagate": False,
75
+ },
76
+ "cloud_pipelines.orchestration.launchers.local_docker_launchers": {
77
+ "level": "DEBUG",
78
+ "handlers": ["default"],
79
+ "propagate": False,
80
+ },
81
+ "uvicorn.error": {
82
+ "level": "DEBUG",
83
+ "handlers": ["default"],
84
+ # Fix triplicated log messages
85
+ "propagate": False,
86
+ },
87
+ "uvicorn.access": {
88
+ # "level": "DEBUG",
89
+ "level": "INFO", # Skip successful GET requests. Does not work...
90
+ "handlers": ["default"],
91
+ },
92
+ "watchfiles.main": {
93
+ "level": "WARNING",
94
+ "handlers": ["default"],
95
+ },
96
+ },
97
+ }
98
+
99
+ logging.config.dictConfig(LOGGING_CONFIG)
100
+
101
+
102
+
103
+ # We want to reduce noise in the HTTP access logs by removing records for successful HTTP GET requests (status code 200)
104
+ class FilterOutSuccessfulHttpGetRequests(logging.Filter):
105
+ def filter(self, record: logging.LogRecord) -> bool:
106
+ # Uvicorn access logs store the status code in record.args[2] (as a string)
107
+ # print(f"{record.args=}, {record=}")
108
+ # record.args:
109
+ # IP, HTTP method, URL, HTTP version, HTTP status code (int)
110
+ try:
111
+ # if record.args[1] == "GET" and 200 <= int(record.args[4]) < 300:
112
+ if record.args[1] == "GET" and int(record.args[4]) in (200, 304):
113
+ record.levelname = "DEBUG"
114
+ # We have to filter out the message here since the dict config fails to filter out the DEBUG messages.
115
+ return False
116
+ except:
117
+ pass
118
+ return True
119
+
120
+
121
+ # Add the filter to the "uvicorn.access" logger
122
+ logging.getLogger("uvicorn.access").addFilter(FilterOutSuccessfulHttpGetRequests())
123
+
124
+
125
+ logger = logging.getLogger(__name__)
126
+ # endregion
127
+
128
+ ENABLE_HUGGINGFACE_AUTH = True
129
+
130
+ MULTI_TENANT_SPACE_IDS = [
131
+ "TangleML/tangle",
132
+ "TangleML/tangle_multi_tenant",
133
+ "Ark-kun/tangle_multi_tenant",
134
+ ]
135
+
136
+ hf_space_id = os.environ.get("SPACE_ID")
137
+ is_hf_space = bool(hf_space_id)
138
+ is_multi_tenant = (
139
+ os.environ.get("MULTI_TENANT", "false").lower() == "true"
140
+ or hf_space_id in MULTI_TENANT_SPACE_IDS
141
+ )
142
+ use_local_debug = not is_hf_space
143
+ # use_local_launcher = use_local_debug
144
+ use_local_launcher = False
145
+
146
+ if use_local_debug:
147
+ is_multi_tenant = True
148
+
149
+ logger.info(f"{hf_space_id=}")
150
+ logger.info(f"{is_hf_space=}")
151
+ logger.info(f"{is_multi_tenant=}")
152
+ logger.info(f"{use_local_debug=}")
153
+ logger.info(f"{use_local_launcher=}")
154
+
155
+
156
+ # region Paths configuration
157
+
158
+ if is_hf_space:
159
+ root_data_dir = "/data/tangle/data_multi_tenant"
160
+ else:
161
+ root_data_dir = "./data_multi_tenant/"
162
+
163
+ root_data_dir_path_obj = pathlib.Path(root_data_dir).resolve()
164
+ root_data_dir_path = str(root_data_dir_path_obj)
165
+ logger.info(f"{root_data_dir_path=}")
166
+
167
+ root_data_dir_path_obj.mkdir(parents=True, exist_ok=True)
168
+ # endregion
169
+
170
+ # region: DB Configuration
171
+ tenant_dir_template_path_obj = root_data_dir_path_obj / "tenants" / "{tenant_id}"
172
+ database_path_template = str(tenant_dir_template_path_obj / "db.sqlite")
173
+ tenants_database_path_obj = root_data_dir_path_obj / "tenants.sqlite"
174
+ tenants_database_path = str(tenants_database_path_obj)
175
+ tenants_database_uri = f"sqlite:///{tenants_database_path}"
176
+
177
+ logger.info(f"{database_path_template=}")
178
+ logger.info(f"{tenants_database_uri=}")
179
+
180
+ tenants_database_path_obj.parent.mkdir(parents=True, exist_ok=True)
181
+ # endregion
182
+
183
+ # region: Orchestrator configuration
184
+ default_task_annotations: dict[str, str] = {}
185
+ sleep_seconds_between_queue_sweeps: float = 5.0
186
+ # endregion
187
+
188
+ # region: Authentication configuration
189
+
190
+ logger.info("os.environ=" + json.dumps(dict(os.environ), indent=2, sort_keys=True))
191
+
192
+ if "HF_TOKEN" in os.environ and is_multi_tenant:
193
+ logger.warning("Warning: Multi-tenant spaces should not have HF_TOKEN set.")
194
+
195
+
196
+ # ! This module is executed during server startup. We do not know any tenants at this point.
197
+ # So, we can almost nothing (cannot check/create repo, cannot construct launcher etc).
198
+
199
+
200
+ @dataclasses.dataclass
201
+ class TenantIdNameToken:
202
+ id: str
203
+ namespace: str
204
+ token: str
205
+
206
+
207
+ # ! parse_huggingface_oauth handles local (non-Space) auth - returned info corresponding to the local HF_TOKEN.
208
+ def get_tenant_info_for_active_user_or_die(
209
+ request: fastapi.Request,
210
+ ) -> TenantIdNameToken:
211
+ oauth_info = huggingface_hub.parse_huggingface_oauth(request)
212
+ if not oauth_info:
213
+ # No -- TODO: Maybe return the demo tenant info?
214
+ raise fastapi.HTTPException(
215
+ status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
216
+ detail="Unauthenticated user",
217
+ )
218
+ # logger.debug(f"get_tenant_info_for_active_user_or_die: {oauth_info=}")
219
+ tenant_info = TenantIdNameToken(
220
+ id=oauth_info.user_info.sub,
221
+ namespace=oauth_info.user_info.preferred_username,
222
+ token=oauth_info.access_token,
223
+ )
224
+ # logger.debug(f"get_tenant_info_for_active_user_or_die: {tenant_info=}")
225
+ return tenant_info
226
+
227
+
228
+ # We're multi-tenant, but single-user (single user per tenant).
229
+ # So the user is always an admin of it's own tenant.
230
+ # TODO: Create multi-user mode.
231
+ # TODO: Enable demo tenants that are readable to anyone.
232
+ def get_user_details(request: fastapi.Request):
233
+ oauth_info = huggingface_hub.parse_huggingface_oauth(request)
234
+ logger.debug(f"get_user_details: {oauth_info=}")
235
+ if oauth_info:
236
+ # We're multi-tenant, but single-user (single user per tenant).
237
+ # So the user is always an admin of it's own tenant.
238
+ # TODO: Create multi-user mode.
239
+ # TODO: Enable demo tenants that are readable to anyone.
240
+ user_details = api_router.UserDetails(
241
+ name=oauth_info.user_info.preferred_username,
242
+ permissions=api_router.Permissions(
243
+ read=True,
244
+ write=True,
245
+ admin=True,
246
+ ),
247
+ )
248
+ logger.debug(f"get_user_details: {user_details=}")
249
+ return user_details
250
+ # FIX: ???!!!
251
+ # Redirect to login?
252
+
253
+ # Return unauthenticated
254
+ return api_router.UserDetails(
255
+ # name="anonymous",
256
+ name=None,
257
+ permissions=api_router.Permissions(
258
+ read=False,
259
+ write=False,
260
+ admin=False,
261
+ ),
262
+ )
263
+
264
+
265
+ # TODO: Switch to async-supporting locks
266
+ db_engines_lock = threading.Lock()
267
+ db_engines: dict[str, sqlalchemy.Engine] = {}
268
+
269
+
270
+ def get_db_engine_for_tenant(tenant_id: str) -> sqlalchemy.Engine:
271
+ db_engine = db_engines.get(tenant_id)
272
+ if db_engine:
273
+ return db_engine
274
+ with db_engines_lock:
275
+ # Double-checked locking
276
+ db_engine = db_engines.get(tenant_id)
277
+ if db_engine:
278
+ return db_engine
279
+ database_path = database_path_template.format(tenant_id=tenant_id)
280
+ database_uri = "sqlite:///" + database_path
281
+ pathlib.Path(database_path).parent.mkdir(parents=True, exist_ok=True)
282
+
283
+ # TODO: Implement "create DB on first write" optimization
284
+ db_engine = database_ops.create_db_engine_and_migrate_db(
285
+ database_uri=database_uri
286
+ )
287
+ db_engines[tenant_id] = db_engine
288
+ return db_engine
289
+
290
+
291
+ def get_db_engine_for_unauthenticated() -> sqlalchemy.Engine:
292
+ return database_ops.create_db_engine_and_migrate_db(
293
+ database_uri="sqlite:///:memory:"
294
+ )
295
+
296
+
297
+ def get_session_factory_for_active_user(
298
+ request: fastapi.Request,
299
+ ) -> typing.Callable[[], orm.Session]:
300
+ try:
301
+ tenant_info = get_tenant_info_for_active_user_or_die(request=request)
302
+ db_engine = get_db_engine_for_tenant(tenant_id=tenant_info.id)
303
+ except:
304
+ logger.debug(
305
+ f"get_session_factory_for_active_user: User is unauthenticated. Returning ephemeral in-memory DB engine."
306
+ )
307
+ db_engine = get_db_engine_for_unauthenticated()
308
+ return orm.sessionmaker(bind=db_engine, autoflush=False)
309
+
310
+
311
+ def get_session_generator_for_active_user(
312
+ request: fastapi.Request,
313
+ ) -> abc.Iterator[orm.Session]:
314
+ session_factory = get_session_factory_for_active_user(request=request)
315
+ with session_factory() as session:
316
+ yield session
317
+
318
+
319
+ def get_launcher_for_tenant(
320
+ tenant_id: str, tenant_namespace: str, tenant_token: str
321
+ ) -> launcher_interfaces.ContainerTaskLauncher[launcher_interfaces.LaunchedContainer]:
322
+ del tenant_id
323
+ if use_local_launcher:
324
+ launcher = local_docker_launchers.DockerContainerLauncher()
325
+ else:
326
+ launcher = huggingface_launchers.HuggingFaceJobsContainerLauncher(
327
+ namespace=tenant_namespace,
328
+ hf_token=tenant_token,
329
+ hf_job_token=tenant_token,
330
+ )
331
+ return launcher
332
+
333
+
334
+ def get_launcher_for_active_user(
335
+ request: fastapi.Request,
336
+ ) -> launcher_interfaces.ContainerTaskLauncher[launcher_interfaces.LaunchedContainer]:
337
+ tenant_info = get_tenant_info_for_active_user_or_die(request=request)
338
+ return get_launcher_for_tenant(
339
+ tenant_id=tenant_info.id,
340
+ tenant_namespace=tenant_info.namespace,
341
+ tenant_token=tenant_info.token,
342
+ )
343
+
344
+
345
+ @dataclasses.dataclass(kw_only=True)
346
+ class OrchestratorInfo:
347
+ tenant_id: str
348
+ tenant_namespace: str
349
+ tenant_token: str
350
+ artifacts_root_uri: str
351
+ logs_root_uri: str
352
+ launcher: launcher_interfaces.ContainerTaskLauncher[
353
+ launcher_interfaces.LaunchedContainer
354
+ ]
355
+ storage_provider: storage_interfaces.StorageProvider
356
+ orchestrator: orchestrator_sql.OrchestratorService_Sql
357
+ orchestrator_thread: threading.Thread
358
+
359
+
360
+ def start_orchestrator_for_tenant(
361
+ tenant_id: str,
362
+ tenant_namespace: str,
363
+ tenant_token: str,
364
+ update_tenants_db: bool = True,
365
+ ) -> OrchestratorInfo:
366
+ logger.info(f"tenant={tenant_namespace}({tenant_id}): Preparing the orchestrator")
367
+ launcher = get_launcher_for_tenant(
368
+ tenant_id=tenant_id,
369
+ tenant_namespace=tenant_namespace,
370
+ tenant_token=tenant_token,
371
+ )
372
+
373
+ db_engine = get_db_engine_for_tenant(tenant_id=tenant_id)
374
+
375
+ if use_local_launcher:
376
+ artifacts_root_uri = str(tenant_dir_template_path_obj / "artifacts").format(
377
+ tenant_id=tenant_id
378
+ )
379
+ logs_root_uri = str(tenant_dir_template_path_obj / "logs").format(
380
+ tenant_id=tenant_id
381
+ )
382
+
383
+ from cloud_pipelines.orchestration.storage_providers import local_storage
384
+
385
+ storage_provider = local_storage.LocalStorageProvider()
386
+ else:
387
+ # Create artifact repo if it does not exist.
388
+ artifacts_repo_id = f"{tenant_namespace}/tangle_data"
389
+ # Do not pollute repos with debug data
390
+ if use_local_debug:
391
+ artifacts_repo_id += "_test"
392
+ ensure_artifact_repo_exists(
393
+ artifacts_repo_id=artifacts_repo_id, token=tenant_token
394
+ )
395
+ artifacts_root_uri = f"hf://datasets/{artifacts_repo_id}/data"
396
+ logs_root_uri = artifacts_root_uri
397
+
398
+ from cloud_pipelines_backend.storage_providers import huggingface_repo_storage
399
+
400
+ # ! Need to pass proper token here!
401
+ hf_client = huggingface_hub.HfApi(token=tenant_token)
402
+ storage_provider = huggingface_repo_storage.HuggingFaceRepoStorageProvider(
403
+ client=hf_client
404
+ )
405
+
406
+ session_factory = orm.sessionmaker(
407
+ autocommit=False, autoflush=False, bind=db_engine
408
+ )
409
+
410
+ # With autobegin=False you always need to begin a transaction, even to query the DB.
411
+ session_factory = orm.sessionmaker(
412
+ autocommit=False, autoflush=False, bind=db_engine
413
+ )
414
+ orchestrator = orchestrator_sql.OrchestratorService_Sql(
415
+ session_factory=session_factory,
416
+ launcher=launcher,
417
+ storage_provider=storage_provider,
418
+ data_root_uri=artifacts_root_uri,
419
+ logs_root_uri=logs_root_uri,
420
+ default_task_annotations=default_task_annotations,
421
+ sleep_seconds_between_queue_sweeps=sleep_seconds_between_queue_sweeps,
422
+ )
423
+ logger.info(f"tenant={tenant_namespace}({tenant_id}): Starting the orchestrator")
424
+ orchestrator_thread = threading.Thread(
425
+ target=orchestrator.run_loop,
426
+ daemon=True,
427
+ )
428
+ orchestrator_thread.start()
429
+
430
+ if update_tenants_db:
431
+ # Recording the orchestrator_info in the tenants DB
432
+ with orm.Session(bind=tenants_db_engine) as session:
433
+ tenant_row = session.get(TenantRow, tenant_id)
434
+ if tenant_row:
435
+ tenant_row.orchestrator_active = True
436
+ launcher_class = type(launcher)
437
+ storage_provider_class = type(storage_provider)
438
+ launcher_class_name = (
439
+ f"{launcher_class.__module__}.{launcher_class.__qualname__}"
440
+ )
441
+ storage_provider_class_name = f"{storage_provider_class.__module__}.{storage_provider_class.__qualname__}"
442
+ tenant_row.orchestrator_config = dict(
443
+ storage_provider_class_name=storage_provider_class_name,
444
+ artifacts_root_uri=artifacts_root_uri,
445
+ logs_root_uri=logs_root_uri,
446
+ launcher_class_name=launcher_class_name,
447
+ )
448
+ session.commit()
449
+ else:
450
+ logging.critical(
451
+ f"start_orchestrator_for_tenant: Started the orchestrator for {tenant_id=}, but tenants DB has no such tenant."
452
+ )
453
+
454
+ return OrchestratorInfo(
455
+ tenant_id=tenant_id,
456
+ tenant_namespace=tenant_namespace,
457
+ tenant_token=tenant_token,
458
+ artifacts_root_uri=artifacts_root_uri,
459
+ logs_root_uri=logs_root_uri,
460
+ launcher=launcher,
461
+ storage_provider=storage_provider,
462
+ orchestrator=orchestrator,
463
+ orchestrator_thread=orchestrator_thread,
464
+ )
465
+
466
+
467
+ # TODO: Switch to async-supporting locks
468
+ orchestrators_lock = threading.Lock()
469
+ orchestrators: dict[str, OrchestratorInfo] = {}
470
+
471
+
472
+ def get_or_start_orchestrator(
473
+ tenant_id: str, tenant_namespace: str, tenant_token: str
474
+ ) -> OrchestratorInfo:
475
+ orchestrator_info = orchestrators.get(tenant_id)
476
+ if orchestrator_info:
477
+ return orchestrator_info
478
+ with orchestrators_lock:
479
+ # Double-checked locking
480
+ orchestrator_info = orchestrators.get(tenant_id)
481
+ if orchestrator_info:
482
+ return orchestrator_info
483
+ orchestrator_info = start_orchestrator_for_tenant(
484
+ tenant_id=tenant_id,
485
+ tenant_namespace=tenant_namespace,
486
+ tenant_token=tenant_token,
487
+ )
488
+ orchestrators[tenant_id] = orchestrator_info
489
+ return orchestrator_info
490
+
491
+
492
+ def get_or_start_orchestrator_for_active_user(
493
+ request: fastapi.Request,
494
+ ) -> OrchestratorInfo:
495
+ tenant_info = get_tenant_info_for_active_user_or_die(request=request)
496
+ orchestrator_info = get_or_start_orchestrator(
497
+ tenant_id=tenant_info.id,
498
+ tenant_namespace=tenant_info.namespace,
499
+ tenant_token=tenant_info.token,
500
+ )
501
+ with orm.Session(bind=tenants_db_engine) as session:
502
+ tenant_row = session.get(TenantRow, tenant_info.id)
503
+ if not tenant_row:
504
+ tenant_row = update_tenant_info_in_db(request=request)
505
+ tenant_row.orchestrator_active = True
506
+ session.commit()
507
+
508
+ return orchestrator_info
509
+
510
+
511
+ def ensure_artifact_repo_exists(artifacts_repo_id: str, token: str):
512
+ hf_client = huggingface_hub.HfApi(token=token)
513
+ repo_type = "dataset"
514
+ repo_exists = False
515
+ try:
516
+ _ = hf_client.repo_info(
517
+ repo_id=artifacts_repo_id,
518
+ repo_type=repo_type,
519
+ )
520
+ repo_exists = True
521
+ logger.debug(
522
+ f"ensure_artifact_repo_exists: Artifact repo exists: {artifacts_repo_id}"
523
+ )
524
+
525
+ except huggingface_hub.errors.RepositoryNotFoundError:
526
+ pass
527
+ except Exception as ex:
528
+ raise RuntimeError(
529
+ f"Error checking for the artifacts repo existence. {artifacts_repo_id=}"
530
+ ) from ex
531
+ if not repo_exists:
532
+ logger.info(
533
+ f"ensure_artifact_repo_exists: Artifact repo does not exist. Creating it: {artifacts_repo_id}"
534
+ )
535
+ try:
536
+ _ = hf_client.create_repo(
537
+ repo_id=artifacts_repo_id,
538
+ repo_type=repo_type,
539
+ private=True,
540
+ exist_ok=True,
541
+ )
542
+ except Exception as ex:
543
+ raise RuntimeError(
544
+ f"Error creating the artifacts repo. {artifacts_repo_id=}"
545
+ ) from ex
546
+
547
+
548
+ def do_stuff_for_tenant(tenant_id: str, tenant_namespace: str):
549
+ del tenant_id
550
+ del tenant_namespace
551
+ # Don't initialize library for HuggingFace users. The initialization might require re-design.
552
+ # # The default library must be initialized here, not when adding the Component Library routes.
553
+ # # Otherwise the tables won't yet exist when initialization is performed.
554
+ # from cloud_pipelines_backend import component_library_api_server as components_api
555
+ # component_library_service = components_api.ComponentLibraryService()
556
+ # db_engine = get_db_engine_for_tenant(tenant_id=tenant_id)
557
+ # with orm.Session(bind=db_engine) as session:
558
+ # component_library_service._initialize_empty_default_library_if_missing(
559
+ # session=session,
560
+ # published_by=tenant_namespace,
561
+ # )
562
+ pass
563
+
564
+
565
+ from sqlalchemy.ext import mutable
566
+
567
+
568
+ class _TenantTableBase(orm.MappedAsDataclass, orm.DeclarativeBase, kw_only=True):
569
+ # Not really needed due to kw_only=True
570
+ _: dataclasses.KW_ONLY
571
+
572
+ # The mutable.MutableDict.as_mutable construct ensures that changes to dictionaries are picked up.
573
+ # This is very important when making changes to `extra_data` dictionaries.
574
+ type_annotation_map = {
575
+ dict: mutable.MutableDict.as_mutable(sqlalchemy.JSON),
576
+ list: mutable.MutableList.as_mutable(sqlalchemy.JSON),
577
+ dict[str, Any]: mutable.MutableDict.as_mutable(sqlalchemy.JSON),
578
+ str: sqlalchemy.String(255),
579
+ }
580
+
581
+
582
+ class TenantRow(_TenantTableBase):
583
+ __tablename__ = "tenant"
584
+
585
+ id: orm.Mapped[str] = orm.mapped_column(primary_key=True)
586
+ name: orm.Mapped[str]
587
+ access_token: orm.Mapped[str]
588
+ oauth_info: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None)
589
+ orchestrator_active: orm.Mapped[bool] = orm.mapped_column(default=False, index=True)
590
+ orchestrator_config: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(
591
+ default=None
592
+ )
593
+ # user_info: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None)
594
+ extra_data: orm.Mapped[dict[str, Any] | None] = orm.mapped_column(default=None)
595
+
596
+
597
+ tenants_db_engine = sqlalchemy.create_engine(url=tenants_database_uri)
598
+
599
+
600
+ def init_tenants_db():
601
+ # tenants_db_engine = sqlalchemy.create_engine(url=tenants_database_uri)
602
+ TenantRow.__table__.create(tenants_db_engine, checkfirst=True)
603
+
604
+
605
+ def update_tenant_info_in_db(request: fastapi.Request) -> TenantRow:
606
+ oauth_info: dict[str, Any] = request.session.get("oauth_info")
607
+ if not oauth_info:
608
+ raise ValueError(
609
+ f"update_tenant_info_in_db: request.session does not have oauth_info."
610
+ )
611
+ logger.debug(f"update_tenant_info_in_db: {oauth_info=}")
612
+ oauth_user_info: dict[str, Any] = oauth_info["userinfo"]
613
+ token: str = oauth_info["access_token"]
614
+ huggingface_user_info: dict[str, Any] = huggingface_hub.whoami(token=token)
615
+ # In local mode, `oauth_user_info["sub"]`` is always "0123456789"
616
+ # We could get the correct ID from `huggingface_user_info["id"]`
617
+ # but huggingface_user_info is not available from the session cookie, so there would be discrepancies between IDs.
618
+ # So, we use `oauth_user_info["sub"]`
619
+ id = oauth_user_info["sub"]
620
+ # id = huggingface_user_info["id"]
621
+ logger.debug(f"update_tenant_info_in_db: {huggingface_user_info=}")
622
+ with orm.Session(bind=tenants_db_engine, expire_on_commit=False) as session:
623
+ # session.merge seems to delete unspecified info (like orchestrator_config) at some circumstances
624
+ # result_row = session.merge(tenant_row)
625
+ tenant_row = session.get(TenantRow, id)
626
+ if tenant_row:
627
+ tenant_row.name = oauth_user_info["preferred_username"]
628
+ tenant_row.access_token = oauth_info["access_token"]
629
+ else:
630
+ tenant_row = TenantRow(
631
+ id=id,
632
+ name=oauth_user_info["preferred_username"],
633
+ access_token=oauth_info["access_token"],
634
+ )
635
+ tenant_row.oauth_info = oauth_info
636
+ extra_data = tenant_row.extra_data or {}
637
+ extra_data["huggingface_user_info"] = huggingface_user_info
638
+ tenant_row.extra_data = dict(extra_data)
639
+
640
+ session.commit()
641
+ session.expunge(tenant_row)
642
+ return tenant_row
643
+
644
+
645
+ def start_all_active_tenant_orchestrators():
646
+ logger.debug(f"start_all_active_tenant_orchestrators")
647
+ with orm.Session(bind=tenants_db_engine) as session:
648
+ for tenant_row in session.scalars(
649
+ sqlalchemy.select(TenantRow).where(TenantRow.orchestrator_active)
650
+ ):
651
+ # TODO: Respect the orchestrator_config
652
+ _ = get_or_start_orchestrator(
653
+ tenant_id=tenant_row.id,
654
+ tenant_namespace=tenant_row.name,
655
+ tenant_token=tenant_row.access_token,
656
+ )
657
+
658
+
659
+ # region: API Server initialization
660
+ @contextlib.asynccontextmanager
661
+ async def lifespan(app: fastapi.FastAPI):
662
+ init_tenants_db()
663
+ start_all_active_tenant_orchestrators()
664
+ yield
665
+ if tenants_db_engine:
666
+ tenants_db_engine.dispose()
667
+
668
+
669
+ app = fastapi.FastAPI(
670
+ title="Cloud Pipelines API",
671
+ version="0.0.1",
672
+ separate_input_output_schemas=False,
673
+ lifespan=lifespan,
674
+ )
675
+
676
+
677
+ @app.exception_handler(Exception)
678
+ def handle_error(request: fastapi.Request, exc: BaseException):
679
+ exception_str = traceback.format_exception(type(exc), exc, exc.__traceback__)
680
+ return fastapi.responses.JSONResponse(
681
+ status_code=503,
682
+ content={"exception": exception_str},
683
+ )
684
+
685
+
686
+ def handle_pipeline_run_creation(request: fastapi.Request):
687
+ # Do nothing before PipelineRun is created
688
+ yield
689
+ # Wake up the orchestrator after the PipelineRun is created
690
+ _ = get_or_start_orchestrator_for_active_user(request=request)
691
+
692
+
693
+ api_router._setup_routes_internal(
694
+ app=app,
695
+ get_session=get_session_generator_for_active_user,
696
+ user_details_getter=get_user_details,
697
+ # TODO: Add
698
+ # container_launcher_for_log_streaming=launcher,
699
+ # TODO: Handle the default library
700
+ # default_component_library_owner_username=default_component_library_owner_username,
701
+ pipeline_run_creation_hook=handle_pipeline_run_creation,
702
+ )
703
+
704
+
705
+ # Health check needed by the Web app
706
+ @app.get("/services/ping")
707
+ def health_check():
708
+ return {}
709
+
710
+
711
+ if ENABLE_HUGGINGFACE_AUTH:
712
+ if "OAUTH_CLIENT_SECRET" not in os.environ:
713
+ logger.warning(
714
+ "HuggingFace auth is enabled, but OAUTH_CLIENT_SECRET env variable is is missing."
715
+ )
716
+ huggingface_hub.attach_huggingface_oauth(app, route_prefix="/api/")
717
+
718
+ # Hook the login callback route to write info to the tenants DB
719
+ # The route is created by huggingface_hub.attach_huggingface_oauth, so we cannot easily control it.
720
+ auth_callback_route_candidates = [
721
+ route
722
+ for route in typing.cast(list[fastapi.routing.APIRoute], app.routes)
723
+ if route.path.endswith("/oauth/huggingface/callback")
724
+ ]
725
+ if len(auth_callback_route_candidates) != 1:
726
+ raise ValueError(f"{auth_callback_route_candidates=}")
727
+ if auth_callback_route_candidates:
728
+ auth_callback_route = auth_callback_route_candidates[0]
729
+ auth_callback_original = auth_callback_route.endpoint
730
+ assert auth_callback_route.dependant.call == auth_callback_route.endpoint
731
+
732
+ def wrapped_auth_callback(*args, **kwargs) -> Any:
733
+ # logger.debug(f"wrapped_auth_callback: {args=}, {kwargs=}")
734
+ result = auth_callback_original(*args, **kwargs)
735
+ request: fastapi.Request = kwargs.get("request") or args[0]
736
+ if "oauth_info" in request.session:
737
+ update_tenant_info_in_db(request=request)
738
+ return result
739
+
740
+ # The `ApiRoute.dependant.call` is the function being called, not the ApiRoute.endpoint.
741
+ # auth_callback_route.endpoint = wrapped_auth_callback
742
+ auth_callback_route.dependant.call = wrapped_auth_callback
743
+
744
+
745
+ # Mounting the web app if the files exist
746
+ this_dir = pathlib.Path(__file__).parent
747
+ web_app_search_dirs = [
748
+ this_dir / ".." / "pipeline-studio-app" / "build",
749
+ this_dir / ".." / "frontend" / "build",
750
+ this_dir / ".." / "frontend_build",
751
+ this_dir / "pipeline-studio-app" / "build",
752
+ ]
753
+ found_frontend_build_files = False
754
+ for web_app_dir in web_app_search_dirs:
755
+ if web_app_dir.exists():
756
+ found_frontend_build_files = True
757
+ logger.info(
758
+ f"Found the Web app static files at {str(web_app_dir)}. Mounting them."
759
+ )
760
+ # The Web app base URL is currently static and hardcoded.
761
+ # TODO: Remove this mount once the base URL becomes relative.
762
+ app.mount(
763
+ "/pipeline-studio-app/",
764
+ staticfiles.StaticFiles(directory=web_app_dir, html=True),
765
+ name="static",
766
+ )
767
+ app.mount(
768
+ "/",
769
+ staticfiles.StaticFiles(directory=web_app_dir, html=True),
770
+ name="static",
771
+ )
772
+ if not found_frontend_build_files:
773
+ logger.warning("The Web app files were not found. Skipping.")
774
+ # endregion
775
+
776
+ if __name__ == "__main__":
777
+ import uvicorn
778
+
779
+ uvicorn.run(app, host="127.0.0.1", port=8000)