"""YOLO client for Label Studio ML Backend.

Command-line script that loads a set of tasks (from a JSON file or from a
comma-separated list of task ids), runs the YOLO model on each of them and
uploads the resulting predictions to a Label Studio project.
"""
import json
import logging
import os
from argparse import ArgumentParser

from tqdm import tqdm
from label_studio_sdk.client import LabelStudio
from label_studio_ml.response import ModelResponse

from model import YOLO

# Connection defaults; each one can be overridden by an environment variable
# and, in turn, by the matching command-line option.
LABEL_STUDIO_URL = os.getenv("LABEL_STUDIO_URL", "http://localhost:8080")
LABEL_STUDIO_API_KEY = os.getenv("LABEL_STUDIO_API_KEY", "your_api_key")
PROJECT_ID = os.getenv("LABEL_STUDIO_PROJECT_ID", "1")

logger = logging.getLogger(__name__)


def arg_parser():
    """Parse the command-line options of the client.

    Returns:
        The ``argparse`` namespace with ``ls_url``, ``ls_api_key``,
        ``project`` and ``tasks`` attributes.
    """
    parser = ArgumentParser(description="YOLO client for Label Studio ML Backend")
    parser.add_argument(
        "--ls-url", type=str, default=LABEL_STUDIO_URL, help="Label Studio URL"
    )
    parser.add_argument(
        "--ls-api-key",
        type=str,
        default=LABEL_STUDIO_API_KEY,
        help="Label Studio API Key",
    )
    parser.add_argument(
        "--project",
        type=str,
        # Use the env-derived constant like the other options do; it still
        # falls back to "1" when LABEL_STUDIO_PROJECT_ID is not set.
        default=PROJECT_ID,
        help="Label Studio Project ID",
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="tasks.json",
        help="Path to tasks JSON file with list of ids or task datas. Example: tasks.json\n"
        "String with ids separated by comma: if you provide task ids, "
        "task data will be downloaded automatically from the Label Studio instance. "
        "Example: 1,2,3",
    )
    return parser.parse_args()


class LabelStudioMLPredictor:
    """Run the YOLO model over Label Studio tasks and upload the predictions."""

    def __init__(self, ls_url: str, ls_api_key: str) -> None:
        """Create the Label Studio SDK client.

        Args:
            ls_url: Base URL of the Label Studio instance.
            ls_api_key: API key passed to the SDK client.
        """
        self.ls = LabelStudio(base_url=ls_url, api_key=ls_api_key)
        logger.info("Successfully connected to Label Studio: %s", ls_url)

    def run(self, project, tasks):
        """Predict every task and send the predictions to Label Studio.

        Args:
            project: Id of the Label Studio project to fetch.
            tasks: Path to a tasks JSON file or a comma-separated string of
                task ids (see ``prepare_tasks``).

        Raises:
            ValueError: If ``tasks`` cannot be turned into a valid task list.
        """
        ls = self.ls
        project = ls.projects.get(id=project)
        logger.info("Project is retrieved: %s", project.id)
        tasks = self.prepare_tasks(ls, tasks)

        # load YOLO model
        # TODO: use get_all_classes_inherited_LabelStudioMLBase to detect model classes
        model = YOLO(project_id=str(project.id), label_config=project.label_config)
        logger.info("YOLO ML backend is created")

        # predict and send prediction to Label Studio
        for task in tqdm(tasks, desc="Predict tasks"):
            response = model.predict([task])
            predictions = self.postprocess_response(model, response, task)
            # postprocess_response returns None when the model produced
            # nothing usable; skip the task instead of iterating over None.
            if not predictions:
                continue

            for prediction in predictions:
                ls.predictions.create(
                    task=task["id"],
                    score=prediction.get("score", 0),
                    model_version=prediction.get("model_version", "none"),
                    result=prediction["result"],
                )

        logger.info("Model predictions are done!")

    @staticmethod
    def postprocess_response(model, response, task):
        """Normalize a model response into a list of prediction dicts.

        Args:
            model: Model that produced ``response``; its ``model_version`` is
                used when the response carries no version of its own.
            response: ``ModelResponse``, a list of prediction dicts (old
                format), or ``None``.
            task: Task the response belongs to (used for logging only).

        Returns:
            The predictions, or ``None`` when the response is ``None`` or of
            an unsupported type.
        """
        if response is None:
            logger.warning(f"No predictions for task: {task}")
            return None

        # model returned ModelResponse
        if isinstance(response, ModelResponse):
            # make sure the response carries a model version before dumping it
            if not response.has_model_version():
                if model.model_version:
                    response.set_version(str(model.model_version))
            else:
                response.update_predictions_version()
            return response.model_dump().get("predictions")

        # model returned list of dicts with predictions (old format)
        if isinstance(response, list):
            return response

        logger.error("No predictions generated by model")
        return None

    @staticmethod
    def prepare_tasks(ls, tasks: str) -> list:
        """Build the list of ``{"id": ..., "data": ...}`` tasks to predict.

        Args:
            ls: Label Studio SDK client, used to download task data when only
                task ids are given.
            tasks: Path to an existing JSON file containing a list of task
                dicts or task ids, or a comma-separated string of task ids.

        Returns:
            A non-empty list of task dicts, each with ``id`` and ``data``.

        Raises:
            ValueError: If ``tasks`` is neither an existing file nor a list
                of integer ids, if the loaded value is not a non-empty list,
                or if the items are not valid task dicts / task ids.
        """
        source = tasks
        if os.path.exists(source):
            with open(source, encoding="utf-8") as f:
                tasks = json.load(f)
        else:
            try:
                tasks = [int(item) for item in source.split(",") if item.strip()]
            except ValueError as err:
                # Most often a mistyped file path; say so instead of leaking
                # the bare int() parsing error.
                raise ValueError(
                    f"'{source}' is neither an existing tasks JSON file "
                    "nor a comma-separated list of integer task ids"
                ) from err

        # Explicit raises instead of asserts: asserts vanish under `python -O`.
        if not isinstance(tasks, list):
            raise ValueError("Tasks should be a list")
        if not tasks:
            raise ValueError("Task list can't be empty")
        logger.info("Detected %d tasks", len(tasks))

        # The first item decides how the list is interpreted.
        first = tasks[0]
        if isinstance(first, dict):
            # check task data for every task, not only the first one
            for task in tasks:
                if not isinstance(task, dict) or "data" not in task or "id" not in task:
                    raise ValueError("'data' and 'id' must be presented in all tasks")
        elif isinstance(first, int):
            # load tasks from Label Studio instance using SDK
            logger.info("Task loading from Label Studio instance ...")
            tasks = [
                {"id": task_id, "data": ls.tasks.get(task_id).data}
                for task_id in tqdm(tasks)
            ]
            logger.info("Task loading finished")
        else:
            raise ValueError(
                "Unknown task format: "
                "tasks should be a list of dicts (task data) or a list of task ids"
            )

        return tasks


if __name__ == "__main__":
    args = arg_parser()
    predictor = LabelStudioMLPredictor(args.ls_url, args.ls_api_key)
    predictor.run(args.project, args.tasks)