Spaces:
Paused
Paused
| import os | |
| import logging | |
| import json | |
| from tqdm import tqdm | |
| from argparse import ArgumentParser | |
| from model import YOLO | |
| from label_studio_sdk.client import LabelStudio | |
| from label_studio_ml.response import ModelResponse | |
# Connection settings: each value is read from the environment and falls back
# to the literal default when the variable is not set.
LABEL_STUDIO_URL = os.environ.get("LABEL_STUDIO_URL", "http://localhost:8080")
LABEL_STUDIO_API_KEY = os.environ.get("LABEL_STUDIO_API_KEY", "your_api_key")
PROJECT_ID = os.environ.get("LABEL_STUDIO_PROJECT_ID", "1")

# Module-level logger used by everything in this script.
logger = logging.getLogger(__name__)
def arg_parser(argv=None):
    """Parse the command-line options of the YOLO Label Studio client.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read the process command line.

    Returns:
        argparse.Namespace with the attributes ``ls_url``, ``ls_api_key``,
        ``project`` and ``tasks``.
    """
    parser = ArgumentParser(description="YOLO client for Label Studio ML Backend")
    parser.add_argument(
        "--ls-url", type=str, default=LABEL_STUDIO_URL, help="Label Studio URL"
    )
    parser.add_argument(
        "--ls-api-key",
        type=str,
        default=LABEL_STUDIO_API_KEY,
        help="Label Studio API Key",
    )
    # Default comes from PROJECT_ID (env LABEL_STUDIO_PROJECT_ID, fallback "1"),
    # mirroring how the URL and API key defaults use their module constants.
    parser.add_argument(
        "--project", type=str, default=PROJECT_ID, help="Label Studio Project ID"
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="tasks.json",
        help="Path to tasks JSON file with list of ids or task datas. Example: tasks.json\n"
        "String with ids separated by comma: if you provide task ids, "
        "task data will be downloaded automatically from the Label Studio instance. Example: 1,2,3",
    )
    return parser.parse_args(argv)
class LabelStudioMLPredictor:
    """Run the YOLO ML backend over Label Studio tasks and upload predictions.

    One Label Studio SDK client is created in ``__init__`` and stored on
    ``self.ls``; ``run`` uses it to fetch the project, load task data when
    only ids were given, and create one prediction per model result.
    """

    def __init__(self, ls_url, ls_api_key):
        """Create the Label Studio SDK client.

        Args:
            ls_url: Base URL handed to the SDK client.
            ls_api_key: API key handed to the SDK client.
        """
        self.ls = LabelStudio(base_url=ls_url, api_key=ls_api_key)
        logger.info(f"Successfully connected to Label Studio: {ls_url}")

    def run(self, project, tasks):
        """Predict every task with the YOLO model and send results to Label Studio.

        Args:
            project: Project id passed to ``ls.projects.get``.
            tasks: Path to a JSON file with tasks, or a comma-separated
                string of task ids (see ``prepare_tasks``).
        """
        ls = self.ls
        project = ls.projects.get(id=project)
        logger.info(f"Project is retrieved: {project.id}")
        tasks = self.prepare_tasks(ls, tasks)

        # load YOLO model
        # TODO: use get_all_classes_inherited_LabelStudioMLBase to detect model classes
        model = YOLO(project_id=str(project.id), label_config=project.label_config)
        logger.info("YOLO ML backend is created")

        # predict and send prediction to Label Studio
        for task in tqdm(tasks, desc="Predict tasks"):
            response = model.predict([task])
            # postprocess_response returns None when nothing usable came back;
            # treat that as "no predictions" instead of iterating over None.
            predictions = self.postprocess_response(model, response, task) or []

            for prediction in predictions:
                ls.predictions.create(
                    task=task["id"],
                    score=prediction.get("score", 0),
                    model_version=prediction.get("model_version", "none"),
                    result=prediction["result"],
                )

        logger.info("Model predictions are done!")

    @staticmethod
    def postprocess_response(model, response, task):
        """Normalize a model response into a list of prediction dicts.

        Declared as a static method because it takes no ``self`` but is
        called through the instance in ``run``.

        Args:
            model: Model object; its ``model_version`` is used when the
                response carries no version.
            response: ``None``, a list of prediction dicts (old format), or
                a ``ModelResponse``.
            task: Task the response belongs to; only used in the warning log.

        Returns:
            The list of predictions, or ``None`` when the response is
            ``None`` or of an unsupported type.
        """
        if response is None:
            logger.warning(f"No predictions for task: {task}")
            return None

        # model returned list of dicts with predictions (old format)
        if isinstance(response, list):
            return response

        # model returned ModelResponse
        if isinstance(response, ModelResponse):
            # NOTE(review): the source lost its indentation; the `else` is
            # paired with the has_model_version() check here (a response that
            # already has a version propagates it to its predictions) - confirm.
            if not response.has_model_version():
                if model.model_version:
                    response.set_version(str(model.model_version))
            else:
                response.update_predictions_version()
            return response.model_dump().get("predictions")

        logger.error("No predictions generated by model")
        return None

    @staticmethod
    def prepare_tasks(ls, tasks):
        """Turn the ``--tasks`` value into a list of ``{"id", "data"}`` dicts.

        Declared as a static method because it takes no ``self`` but is
        called through the instance in ``run``.

        Args:
            ls: Label Studio SDK client; used only when task ids must be
                resolved to task data.
            tasks: Path to an existing JSON file (list of task dicts or list
                of integer ids), or a comma-separated string of integer ids.

        Returns:
            List of task dicts, each with ``"id"`` and ``"data"`` keys.

        Raises:
            ValueError: If the value is neither an existing file nor integer
                ids, if a task dict lacks ``"data"``/``"id"``, or if the task
                format is unknown.
            AssertionError: If the loaded tasks are not a non-empty list.
        """
        # get tasks
        if os.path.exists(tasks):
            with open(tasks, encoding="utf-8") as f:
                tasks = json.load(f)
        else:
            try:
                tasks = [int(task_id) for task_id in tasks.split(",")]
            except ValueError as err:
                # `tasks` is still the raw string here: the assignment above
                # never happened because int() raised.
                raise ValueError(
                    "Tasks must be a path to an existing JSON file or "
                    f"comma-separated integer task ids, got: {tasks!r}"
                ) from err

        assert isinstance(tasks, list), "Tasks should be a list"
        assert len(tasks) > 0, "'Task list can't be empty"
        logger.info(f"Detected {len(tasks)} tasks")

        # check task data
        if isinstance(tasks[0], dict):
            # Validate every task, not just the first one, so a malformed
            # entry is reported here rather than as a KeyError mid-run.
            if any(
                not isinstance(task, dict) or "data" not in task or "id" not in task
                for task in tasks
            ):
                raise ValueError("'data' and 'id' must be presented in all tasks")

        elif isinstance(tasks[0], int):
            # load tasks from Label Studio instance using SDK
            logger.info("Task loading from Label Studio instance ...")
            tasks = [
                {"id": task_id, "data": ls.tasks.get(task_id).data}
                for task_id in tqdm(tasks)
            ]
            logger.info("Task loading finished")

        else:
            raise ValueError(
                "Unknown task format: "
                "tasks should be a list of dicts (task data) or a list of task ids"
            )

        return tasks
if __name__ == "__main__":
    # Script entry point: parse the CLI options, build the predictor from the
    # connection options, then run it for the requested project and tasks.
    cli_options = arg_parser()
    LabelStudioMLPredictor(cli_options.ls_url, cli_options.ls_api_key).run(
        cli_options.project, cli_options.tasks
    )