File size: 1,528 Bytes
9ca7e5a
 
 
 
 
 
68f48a7
 
 
 
9ca7e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68f48a7
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations

from typing import Optional

from fastapi import Depends, HTTPException, Query, Request

from api.classify.banks import LabelSetBank
from api.label_sets.registry import LabelSetRegistry
from api.model.clip_store import ClipStore
from api.classify.service import TwoStageClassifier


def get_request_id(request: Request) -> str:
    """Return the per-request identifier stored on ``request.state``.

    NOTE(review): ``request_id`` is presumably attached by a middleware
    upstream; if it was never set, attribute access on the state fails.
    """
    state = request.state
    return state.request_id


def get_label_set_hash(
    label_set_hash: Optional[str] = Query(
        None,
        description="If omitted, uses the default activated label set.",
    ),
) -> Optional[str]:
    """Expose the optional ``label_set_hash`` query parameter as a dependency.

    Returns:
        The raw query value, or ``None`` when the client did not supply one.
    """
    requested_hash = label_set_hash
    return requested_hash


def resolve_bank(registry: LabelSetRegistry, label_set_hash: Optional[str]) -> LabelSetBank:
    """Look up a label-set bank, translating registry misses into HTTP errors.

    Args:
        registry: Registry queried via ``registry.resolve``.
        label_set_hash: Hash of the requested label set, or ``None``
            (presumably meaning "use the registry's default" — confirm
            against ``LabelSetRegistry.resolve``).

    Returns:
        Whatever ``registry.resolve(label_set_hash)`` returns.

    Raises:
        HTTPException: 400 when the registry's ``KeyError`` message contains
            "No default"; 404 for any other ``KeyError``.
    """
    try:
        return registry.resolve(label_set_hash)
    except KeyError as e:
        # The registry distinguishes "no default configured" from "unknown
        # hash" only through the KeyError text.
        # NOTE(review): a dedicated exception type would be less brittle.
        # Chain with `from e` so the original lookup failure is preserved as
        # the explicit cause in logs/tracebacks.
        if "No default" in str(e):
            raise HTTPException(
                status_code=400,
                detail="No default label set. Upload one and activate it.",
            ) from e
        raise HTTPException(status_code=404, detail="Unknown label_set_hash") from e


def get_registry(request: Request) -> LabelSetRegistry:
    """Return the label-set registry held on ``app.state.resources``.

    Defined before :func:`get_bank` because it is used as that function's
    ``Depends`` default, which Python evaluates when ``get_bank`` is defined.
    """
    return request.app.state.resources.registry


def get_bank(
    registry: LabelSetRegistry = Depends(get_registry),
    label_set_hash: Optional[str] = Depends(get_label_set_hash),
) -> LabelSetBank:
    """Dependency that resolves the label-set bank for the current request.

    ``registry`` is injected via :func:`get_registry`. Without an explicit
    ``Depends``, FastAPI would try to read an arbitrary class-typed parameter
    from the request itself. ``label_set_hash`` comes from the optional query
    parameter. Both can still be passed explicitly when calling this function
    directly.

    Raises:
        HTTPException: propagated from :func:`resolve_bank` (400/404).
    """
    return resolve_bank(registry, label_set_hash)


def get_store(request: Request) -> ClipStore:
    """Return the clip store held on ``app.state.resources``.

    NOTE(review): ``resources`` is presumably populated at application
    startup; confirm where it is attached.
    """
    return request.app.state.resources.store


def get_classifier(request: Request) -> TwoStageClassifier:
    """Return the two-stage classifier held on ``app.state.resources``."""
    return request.app.state.resources.classifier