# ls-yolo-backend / cli.py
# Author: davanstrien (HF Staff)
# Commit 3f7dd83 (verified): Initial: HumanSignal yolo example patched for HF Spaces
import os
import logging
import json
from tqdm import tqdm
from argparse import ArgumentParser
from model import YOLO
from label_studio_sdk.client import LabelStudio
from label_studio_ml.response import ModelResponse
# Connection defaults for the CLI; each one can be overridden through the
# environment variable of the same (or similar) name.
LABEL_STUDIO_URL = os.environ.get("LABEL_STUDIO_URL", "http://localhost:8080")
# NOTE(review): "your_api_key" is only a placeholder and will not authenticate.
LABEL_STUDIO_API_KEY = os.environ.get("LABEL_STUDIO_API_KEY", "your_api_key")
PROJECT_ID = os.environ.get("LABEL_STUDIO_PROJECT_ID", "1")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def arg_parser(argv=None):
    """Parse the command-line options of the YOLO Label Studio client.

    Defaults for ``--ls-url``, ``--ls-api-key`` and ``--project`` come from the
    module-level LABEL_STUDIO_URL, LABEL_STUDIO_API_KEY and PROJECT_ID
    constants, i.e. from the corresponding environment variables.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``ls_url``, ``ls_api_key``, ``project`` and
        ``tasks`` attributes.
    """
    parser = ArgumentParser(description="YOLO client for Label Studio ML Backend")
    parser.add_argument(
        "--ls-url", type=str, default=LABEL_STUDIO_URL, help="Label Studio URL"
    )
    parser.add_argument(
        "--ls-api-key",
        type=str,
        default=LABEL_STUDIO_API_KEY,
        help="Label Studio API Key",
    )
    parser.add_argument(
        "--project",
        type=str,
        # Honour LABEL_STUDIO_PROJECT_ID (read into PROJECT_ID) the same way
        # the URL and API key options honour their environment variables.
        default=PROJECT_ID,
        help="Label Studio Project ID",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="tasks.json",
        help="Path to tasks JSON file with list of ids or task datas. Example: tasks.json\n"
        "String with ids separated by comma: if you provide task ids, "
        "task data will be downloaded automatically from the Label Studio instance. Example: 1,2,3",
    )
    return parser.parse_args(argv)
class LabelStudioMLPredictor:
    """Run the YOLO ML backend over Label Studio tasks and upload predictions.

    A single Label Studio SDK client is created in ``__init__``. ``run`` reuses
    it to fetch the project, load task data and create predictions.
    """

    def __init__(self, ls_url, ls_api_key):
        """Create the Label Studio SDK client.

        Args:
            ls_url: Base URL of the Label Studio instance.
            ls_api_key: API key used to authenticate SDK requests.
        """
        self.ls = LabelStudio(base_url=ls_url, api_key=ls_api_key)
        # NOTE(review): constructing the client may not perform a request, so
        # "connected" may be optimistic until the first API call succeeds.
        logger.info(f"Successfully connected to Label Studio: {ls_url}")

    def run(self, project, tasks):
        """Predict every task with YOLO and store the results in Label Studio.

        Args:
            project: Label Studio project id, passed to ``ls.projects.get``.
            tasks: Path to a JSON file, or a comma-separated string of task
                ids; see ``prepare_tasks`` for the accepted formats.
        """
        ls = self.ls
        project = ls.projects.get(id=project)
        logger.info(f"Project is retrieved: {project.id}")
        tasks = self.prepare_tasks(ls, tasks)

        # load YOLO model
        # TODO: use get_all_classes_inherited_LabelStudioMLBase to detect model classes
        model = YOLO(project_id=str(project.id), label_config=project.label_config)
        logger.info("YOLO ML backend is created")

        # predict and send predictions to Label Studio, one task at a time
        for task in tqdm(tasks, desc="Predict tasks"):
            response = model.predict([task])
            # postprocess_response() returns None when the model produced
            # nothing usable; treat that as "no predictions for this task"
            # instead of crashing the whole run on iteration.
            predictions = self.postprocess_response(model, response, task) or []
            for prediction in predictions:
                ls.predictions.create(
                    task=task["id"],
                    score=prediction.get("score", 0),
                    model_version=prediction.get("model_version", "none"),
                    result=prediction["result"],
                )
        logger.info("Model predictions are done!")

    @staticmethod
    def postprocess_response(model, response, task):
        """Convert a ``model.predict`` result into a list of prediction dicts.

        Args:
            model: The model that produced ``response``. Its ``model_version``
                is used when the response carries no version.
            response: A ``ModelResponse``, a plain list of prediction dicts
                (old format), or ``None``.
            task: The task that was predicted. It is used only for logging.

        Returns:
            The list of predictions, or ``None`` when there is nothing to send.
        """
        if response is None:
            logger.warning(f"No predictions for task: {task}")
            return None

        # model returned ModelResponse
        if isinstance(response, ModelResponse):
            # make sure every prediction carries a model version
            if not response.has_model_version():
                if model.model_version:
                    response.set_version(str(model.model_version))
            else:
                response.update_predictions_version()

            response = response.model_dump()
            # may be None if the dump has no "predictions" key
            predictions = response.get("predictions")

        # model returned list of dicts with predictions (old format)
        elif isinstance(response, list):
            predictions = response

        else:
            # unexpected return type from the model
            logger.error("No predictions generated by model")
            return None

        return predictions

    @staticmethod
    def prepare_tasks(ls, tasks):
        """Normalize the ``tasks`` argument into a list of task dicts.

        ``tasks`` is either a path to a JSON file holding a list, or a
        comma-separated string of integer task ids (e.g. ``"1,2,3"``).

        - A list of ints is resolved into ``{"id", "data"}`` dicts by fetching
          each task from Label Studio through ``ls``.
        - A list of dicts is returned as-is, after checking that every entry
          has both ``"id"`` and ``"data"`` keys.

        Raises:
            ValueError: if ``tasks`` is neither an existing file nor a list of
                ids, if the list is empty, or if the entries are malformed or
                mix both formats.
        """
        if os.path.exists(tasks):
            with open(tasks, encoding="utf-8") as f:
                tasks = json.load(f)
        else:
            # Skip empty pieces so a trailing comma ("1,2,") is tolerated.
            parts = [part for part in tasks.split(",") if part.strip()]
            try:
                tasks = [int(part) for part in parts]
            except ValueError as err:
                raise ValueError(
                    f"'{tasks}' is neither an existing tasks file "
                    "nor a comma-separated list of task ids"
                ) from err

        # Explicit raises instead of asserts: asserts are stripped under -O.
        if not isinstance(tasks, list):
            raise ValueError("Tasks should be a list")
        if not tasks:
            raise ValueError("Task list can't be empty")
        logger.info(f"Detected {len(tasks)} tasks")

        # Validate every entry, not only the first one, so malformed or mixed
        # lists fail here instead of midway through prediction.
        if all(isinstance(task, dict) for task in tasks):
            if not all("data" in task and "id" in task for task in tasks):
                raise ValueError("'data' and 'id' must be presented in all tasks")
        elif all(isinstance(task, int) for task in tasks):
            # load tasks from Label Studio instance using SDK
            logger.info("Task loading from Label Studio instance ...")
            tasks = [
                {"id": task_id, "data": ls.tasks.get(task_id).data}
                for task_id in tqdm(tasks)
            ]
            logger.info("Task loading finished")
        else:
            raise ValueError(
                "Unknown task format: "
                "tasks should be a list of dicts (task data) or a list of task ids"
            )
        return tasks
if __name__ == "__main__":
    # The root logger defaults to WARNING, so without this every
    # logger.info() progress message in this module is silently dropped.
    # basicConfig() is a no-op if logging was already configured elsewhere.
    logging.basicConfig(level=logging.INFO)
    args = arg_parser()
    predictor = LabelStudioMLPredictor(args.ls_url, args.ls_api_key)
    predictor.run(args.project, args.tasks)