Spaces:
Runtime error
| import random | |
| from PIL import ImageDraw, Image | |
| from huggingface_hub import hf_hub_download | |
| from ultralytics import YOLO | |
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    """
    Draw one bounding box, and optionally a label, onto a PIL image in place.

    :param x: box corners ``(x1, y1, x2, y2)`` in pixels; each value is truncated with ``int``.
    :param img: PIL image to draw on; it is modified in place.
    :param color: outline / label-background colour; a random RGB tuple is used when falsy.
    :param line_thickness: outline width in pixels; when falsy it is derived from the image size.
    :param label: optional text drawn in white on a ``color``-filled background anchored at
        the box's top-left corner.
    :return: None
    """
    width, height = img.size
    # line width scales with the image's mean side length; round(...) >= 0 so this is >= 1
    tl = line_thickness or round(0.002 * (width + height) / 2) + 1
    color = color or (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    img_draw = ImageDraw.Draw(img)
    img_draw.rectangle((c1[0], c1[1], c2[0], c2[1]), outline=color, width=tl)
    if label:
        tf = max(tl - 1, 1)  # used as stroke width only to pad the label background
        # textbbox reports where the text lands when drawn at c1 (stroke padding included)
        label_box = img_draw.textbbox(c1, label, stroke_width=tf)
        img_draw.rectangle(label_box, fill=color)
        # draw at the same anchor the bbox was measured from, so the text sits inside its
        # background instead of being shifted by the bbox's stroke/ascender offsets
        img_draw.text(c1, label, fill=(255, 255, 255))
def add_bboxes(pil_img, result, confidence=0.6):
    """
    Draw each detection in ``result`` whose score is at least ``confidence`` onto ``pil_img``.

    :param pil_img: PIL image to draw on; it is modified in place.
    :param result: a single detection result; its ``boxes`` (each with ``cls``, ``conf`` and
        ``xyxy`` exposing ``tolist()``) and ``names`` (class id -> name) are read.
    :param confidence: minimum score for a box to be drawn; lower-scoring boxes are skipped.
    :return: the same ``pil_img`` object.
    """
    for box in result.boxes:
        [cl] = box.cls.tolist()
        [conf] = box.conf.tolist()
        if conf < confidence:
            continue
        [rect] = box.xyxy.tolist()
        # the class id may come back as a float; int() keeps indexing into ``names`` valid
        # NOTE(review): presumably names is keyed by int class ids — confirm against caller
        text = f'{result.names[int(cl)]}: {conf:.2f}'
        plot_one_box(x=rect, img=pil_img, label=text)
    return pil_img
class YoloModel:
    """
    Wrapper around an ultralytics ``YOLO`` model whose weights are fetched from the Hugging Face Hub.
    """

    def __init__(self, repo_name: str, file_name: str):
        """
        Download ``file_name`` from Hub repository ``repo_name`` and load it as a YOLO model.

        :param repo_name: Hub repository id, e.g. ``"SHOU-ISD/fire-and-smoke"``.
        :param file_name: weight file inside that repository, e.g. ``"yolov8n.pt"``.
        """
        weight_file = YoloModel.download_weight_file(repo_name, file_name)
        self.model = YOLO(weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """
        Fetch ``file_name`` from Hub repository ``repo_name`` and return what ``hf_hub_download``
        returns (the local path of the downloaded file).

        Declared static because it uses no instance state; it can be called on the class or an instance.
        """
        return hf_hub_download(repo_name, file_name)

    def detect(self, im):
        """
        Run the model on ``im`` and return its raw results.
        """
        return self.model(source=im)

    def preview_detect(self, im, confidence):
        """
        Run detection on ``im`` and draw the boxes scoring at least ``confidence`` onto it.

        ``im`` must be a PIL image because it is handed to ``add_bboxes`` for drawing; it is
        modified in place and also returned.

        NOTE(review): boxes the model itself discards with its own default score threshold never
        reach this filter, so a ``confidence`` below that threshold has no extra effect — confirm
        whether ``confidence`` should also be forwarded to the model call.
        """
        res_img = im
        for result in self.detect(im):
            res_img = add_bboxes(res_img, result, confidence)
        return res_img
def test():
    """
    Manual smoke test: run the fire-and-smoke model on ``./tests/fire1.jpg`` and print the boxes.

    Boxes scoring at least 0.1 are drawn onto the in-memory image; the annotated image is not
    saved. Needs the sample image on disk and the weight file downloadable from the Hub.
    """
    model = YoloModel("SHOU-ISD/fire-and-smoke", "yolov8n.pt")
    # the context manager closes the file handle Image.open keeps for lazy loading
    with Image.open("./tests/fire1.jpg") as im:
        for result in model.detect(im):
            add_bboxes(im, result, confidence=0.1)  # draws in place
            print(result.boxes)
def argument_parser():
    """
    Parse this script's command-line options.

    Either ``--test`` or at least one ``--weight_files`` entry must be given; otherwise argparse
    prints the usage message and exits with status 2.

    :return: argparse.Namespace with ``test`` (bool) and ``weight_files`` (list of
        ``"repo_name:file_name"`` strings, or None when the option was not given).
    """
    import argparse
    parser = argparse.ArgumentParser(description='Help for YoloModel')
    parser.add_argument('--test', '-t', action='store_true', help='Run test')
    # each entry is "repo_name:file_name", e.g. "SHOU-ISD/fire-and-smoke:yolov8n.pt"
    parser.add_argument('--weight_files', '-w', nargs='+', help='List of weight files')
    args = parser.parse_args()
    # without this check the pre-cache path would iterate over None and crash with a TypeError
    if not args.test and not args.weight_files:
        parser.error('--weight_files is required unless --test is given')
    return args
def pre_cache_weight_files(weight_files: list[str]):
    """
    Download weight files from the Hugging Face Hub ahead of time.

    :param weight_files: entries of the form ``"repo_name:file_name"``; only the first ``:``
        separates the two parts, so the file name itself may contain ``:``.
    :raises ValueError: if an entry has no ``:`` or either part is empty. Every entry is
        validated before any download starts.
    :return: None
    """
    pairs = []
    for entry in weight_files:
        repo_name, sep, file_name = entry.partition(":")
        if not (sep and repo_name and file_name):
            raise ValueError(f"expected 'repo_name:file_name', got {entry!r}")
        pairs.append((repo_name, file_name))
    for repo_name, file_name in pairs:
        YoloModel.download_weight_file(repo_name, file_name)
if __name__ == '__main__':
    # --test runs the local smoke test; otherwise the listed weight files are pre-downloaded
    cli_args = argument_parser()
    if not cli_args.test:
        pre_cache_weight_files(cli_args.weight_files)
    else:
        test()