# Residue from the web file view this was copied from:
# "Spaces: Paused", revision 3f7dd83, file size 5,013 bytes.
import os
import logging
import json
from tqdm import tqdm
from argparse import ArgumentParser
from model import YOLO
from label_studio_sdk.client import LabelStudio
from label_studio_ml.response import ModelResponse
# Connection defaults, overridable through the environment (read once at import).
LABEL_STUDIO_URL = os.getenv("LABEL_STUDIO_URL", "http://localhost:8080")
LABEL_STUDIO_API_KEY = os.getenv("LABEL_STUDIO_API_KEY", "your_api_key")
PROJECT_ID = os.getenv("LABEL_STUDIO_PROJECT_ID", "1")

logger = logging.getLogger(__name__)


def arg_parser(args=None):
    """Parse command-line options for the YOLO prediction client.

    The defaults for ``--ls-url``, ``--ls-api-key`` and ``--project`` come from
    ``LABEL_STUDIO_URL``, ``LABEL_STUDIO_API_KEY`` and ``PROJECT_ID``, which
    are populated from the environment at import time.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic use and tests).

    Returns:
        argparse.Namespace with string attributes ``ls_url``, ``ls_api_key``,
        ``project`` and ``tasks``.
    """
    parser = ArgumentParser(description="YOLO client for Label Studio ML Backend")
    parser.add_argument(
        "--ls-url", type=str, default=LABEL_STUDIO_URL, help="Label Studio URL"
    )
    parser.add_argument(
        "--ls-api-key",
        type=str,
        default=LABEL_STUDIO_API_KEY,
        help="Label Studio API Key",
    )
    parser.add_argument(
        "--project",
        type=str,
        # Honor LABEL_STUDIO_PROJECT_ID, consistent with --ls-url/--ls-api-key.
        default=PROJECT_ID,
        help="Label Studio Project ID",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="tasks.json",
        help="Path to tasks JSON file with list of ids or task datas. Example: tasks.json\n"
        "String with ids separated by comma: if you provide task ids, "
        "task data will be downloaded automatically from the Label Studio instance. Example: 1,2,3",
    )
    return parser.parse_args(args)
class LabelStudioMLPredictor:
    """Run a YOLO ML backend over Label Studio tasks and upload its predictions."""

    def __init__(self, ls_url, ls_api_key):
        """Create the Label Studio SDK client.

        Args:
            ls_url: Base URL of the Label Studio instance.
            ls_api_key: API key passed to the SDK client.
        """
        self.ls = LabelStudio(base_url=ls_url, api_key=ls_api_key)
        # NOTE(review): this logs after constructing the client object; it is
        # not evidence that a request to the server succeeded — confirm.
        logger.info(f"Successfully connected to Label Studio: {ls_url}")

    def run(self, project, tasks):
        """Predict every task with a YOLO backend and store the results.

        Args:
            project: Label Studio project id, forwarded to ``ls.projects.get``.
            tasks: Path to a JSON task file or a comma-separated string of
                task ids; see ``prepare_tasks``.
        """
        ls = self.ls
        project = ls.projects.get(id=project)
        logger.info(f"Project is retrieved: {project.id}")
        tasks = self.prepare_tasks(ls, tasks)

        # TODO: use get_all_classes_inherited_LabelStudioMLBase to detect model classes
        model = YOLO(project_id=str(project.id), label_config=project.label_config)
        logger.info("YOLO ML backend is created")

        for task in tqdm(tasks, desc="Predict tasks"):
            response = model.predict([task])
            predictions = self.postprocess_response(model, response, task)
            # postprocess_response returns None when the model produced nothing
            # usable; skip the task rather than crash iterating over None.
            if not predictions:
                continue
            for prediction in predictions:
                ls.predictions.create(
                    task=task["id"],
                    score=prediction.get("score", 0),
                    model_version=prediction.get("model_version", "none"),
                    result=prediction["result"],
                )
        logger.info("Model predictions are done!")

    @staticmethod
    def postprocess_response(model, response, task):
        """Normalize a backend ``predict`` result into a list of prediction dicts.

        Args:
            model: The backend that produced ``response``; its ``model_version``
                is applied when the response carries no version.
            response: A ``ModelResponse``, a plain list of prediction dicts
                (older backend format), or ``None``.
            task: The task that was predicted (used only in the log message).

        Returns:
            A list of prediction dicts, or ``None`` if nothing usable came back.
        """
        if response is None:
            logger.warning(f"No predictions for task: {task}")
            return None

        if isinstance(response, ModelResponse):
            # Ensure a model version is set: use the backend's version when the
            # response has none; otherwise call update_predictions_version()
            # (presumably propagates the response version to each prediction —
            # confirm against label_studio_ml).
            if not response.has_model_version():
                if model.model_version:
                    response.set_version(str(model.model_version))
            else:
                response.update_predictions_version()
            return response.model_dump().get("predictions")

        if isinstance(response, list):
            # Older backends return the predictions list directly.
            return response

        logger.error("No predictions generated by model")
        return None

    @staticmethod
    def prepare_tasks(ls, tasks):
        """Resolve the ``--tasks`` argument into a list of ``{"id", "data"}`` dicts.

        ``tasks`` is either a path to a JSON file holding a list, or a
        comma-separated string of integer task ids (e.g. ``"1,2,3"``). The list
        must contain either task dicts (each with ``"id"`` and ``"data"``) or
        integer task ids; for ids, task data is fetched through ``ls``.

        Args:
            ls: Label Studio SDK client used to download task data.
            tasks: File path or comma-separated id string.

        Returns:
            A list of task dicts.

        Raises:
            ValueError: if the input cannot be parsed, is not a non-empty list,
                or mixes/uses unsupported task formats.
        """
        spec = tasks
        if os.path.isfile(spec):
            with open(spec, encoding="utf-8") as f:
                tasks = json.load(f)
        else:
            try:
                # Blank items are ignored so "1,2," and "1, 2" are accepted.
                tasks = [int(item) for item in spec.split(",") if item.strip()]
            except ValueError as err:
                raise ValueError(
                    "Tasks must be a path to an existing JSON file or "
                    f"comma-separated task ids, got: {spec!r}"
                ) from err

        # Explicit checks instead of assert: asserts vanish under `python -O`.
        if not isinstance(tasks, list):
            raise ValueError("Tasks should be a list")
        if not tasks:
            raise ValueError("Task list can't be empty")
        logger.info(f"Detected {len(tasks)} tasks")

        # Validate every element, not just the first one.
        if all(isinstance(task, dict) for task in tasks):
            if any("data" not in task or "id" not in task for task in tasks):
                raise ValueError("'data' and 'id' must be presented in all tasks")
        elif all(isinstance(task, int) for task in tasks):
            logger.info("Task loading from Label Studio instance ...")
            tasks = [
                {"id": task_id, "data": ls.tasks.get(task_id).data}
                for task_id in tqdm(tasks)
            ]
            logger.info("Task loading finished")
        else:
            raise ValueError(
                "Unknown task format: "
                "tasks should be a list of dicts (task data) or a list of task ids"
            )
        return tasks
if __name__ == "__main__":
    # Configure a handler at INFO level; otherwise the root logger's default
    # WARNING level silently drops every logger.info progress message.
    logging.basicConfig(level=logging.INFO)
    args = arg_parser()
    predictor = LabelStudioMLPredictor(args.ls_url, args.ls_api_key)
    predictor.run(args.project, args.tasks)