Spaces:
Sleeping
Sleeping
File size: 1,528 Bytes
9ca7e5a 68f48a7 9ca7e5a 68f48a7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | from __future__ import annotations
from typing import Optional
from fastapi import Depends, HTTPException, Query, Request
from api.classify.banks import LabelSetBank
from api.label_sets.registry import LabelSetRegistry
from api.model.clip_store import ClipStore
from api.classify.service import TwoStageClassifier
def get_request_id(request: Request) -> str:
    """Return the ``request_id`` attribute stored on ``request.state``.

    Raises ``AttributeError`` if nothing has set ``request.state.request_id``.
    """
    state = request.state
    return getattr(state, "request_id")
def get_label_set_hash(
    label_set_hash: Optional[str] = Query(
        None,
        description="If omitted, uses the default activated label set.",
    ),
) -> Optional[str]:
    """Expose the optional ``label_set_hash`` query parameter as a dependency.

    Returns the raw query value, or ``None`` when the client did not send one.
    """
    selected = label_set_hash
    return selected
def resolve_bank(registry: LabelSetRegistry, label_set_hash: Optional[str]) -> LabelSetBank:
    """Look up a label-set bank in ``registry``, translating lookup errors to HTTP errors.

    Args:
        registry: Object whose ``resolve(label_set_hash)`` returns the bank.
        label_set_hash: Hash of the requested label set, or ``None`` to ask the
            registry for its default.

    Returns:
        Whatever ``registry.resolve(label_set_hash)`` returns.

    Raises:
        HTTPException: 400 when the registry's ``KeyError`` message contains
            "No default", otherwise 404 for any other ``KeyError``. The original
            ``KeyError`` is chained as ``__cause__``.
    """
    try:
        return registry.resolve(label_set_hash)
    except KeyError as e:
        # The registry signals "no default set" vs "unknown hash" only through
        # the KeyError message text, so the branch keys off that substring.
        msg = str(e)
        if "No default" in msg:
            raise HTTPException(
                status_code=400,
                detail="No default label set. Upload one and activate it.",
            ) from e
        raise HTTPException(status_code=404, detail="Unknown label_set_hash") from e
def _registry_dependency(request: Request) -> LabelSetRegistry:
    """Return the label-set registry for ``get_bank`` by delegating to ``get_registry``.

    ``get_registry`` is defined later in this module, so it cannot appear in a
    default argument (evaluated at import time). This wrapper defers the name
    lookup until FastAPI actually calls the dependency.
    """
    return get_registry(request)


def get_bank(
    registry: LabelSetRegistry = Depends(_registry_dependency),
    label_set_hash: Optional[str] = Depends(get_label_set_hash),
) -> LabelSetBank:
    """Dependency that yields the label-set bank selected by the request.

    ``registry`` now has a ``Depends`` default, so FastAPI injects it rather
    than trying to read a ``LabelSetRegistry`` from the request. Direct callers
    that pass ``registry`` explicitly are unaffected.

    Raises:
        HTTPException: Propagated from ``resolve_bank`` (400 when no default
            label set exists, 404 for an unknown hash).
    """
    return resolve_bank(registry, label_set_hash)
def get_store(request: Request) -> ClipStore:
    """Return the ``store`` entry of ``request.app.state.resources``."""
    resources = request.app.state.resources
    return resources.store
def get_classifier(request: Request) -> TwoStageClassifier:
    """Return the ``classifier`` entry of ``request.app.state.resources``."""
    app_state = request.app.state
    classifier = app_state.resources.classifier
    return classifier
def get_registry(request: Request) -> LabelSetRegistry:
    """Return the ``registry`` entry of ``request.app.state.resources``."""
    app = request.app
    shared = app.state.resources
    return shared.registry
|