Spaces:
Running on Zero
| #!/usr/bin/env python | |
| import pathlib | |
| import sys | |
| import tarfile | |
| import cv2 | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import PIL.Image | |
| import spaces | |
| import torch | |
| sys.path.insert(0, "yolov5_anime") | |
| from models.yolo import Model # pyright: ignore[reportMissingImports] | |
| from utils.datasets import letterbox # pyright: ignore[reportMissingImports] | |
| from utils.general import non_max_suppression, scale_coords # pyright: ignore[reportMissingImports] | |
# Markdown heading rendered at the top of the demo UI.
DESCRIPTION = "# [zymk9/yolov5_anime](https://github.com/zymk9/yolov5_anime)"

torch.set_grad_enabled(False)  # inference-only app; never needs autograd
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fetch the detector weights and its architecture config from the Hub.
MODEL_REPO = "public-data/yolov5_anime"
weights_file = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x_anime.pth")
cfg_file = huggingface_hub.hf_hub_download(MODEL_REPO, "yolov5x.yaml")

model = Model(cfg=cfg_file)
model.load_state_dict(torch.load(weights_file, weights_only=True))
if device.type != "cpu":
    model.half()  # fp16 on GPU — matches the half-precision inputs built in predict()
model.to(device)
model.eval()
def load_sample_image_paths() -> list[pathlib.Path]:
    """Return sorted paths of the sample images, downloading them on first use.

    Downloads ``images.tar.gz`` from the sample-image dataset repo and extracts
    it into ``./images`` if that directory does not exist yet.

    Returns:
        Sorted list of every path directly under the local ``images/`` directory.
    """
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            # The "data" filter rejects absolute paths, parent-directory
            # traversal, and special files, guarding against malicious archive
            # members (CVE-2007-4559) — safer than a bare extractall().
            f.extractall(filter="data")
    return sorted(image_dir.glob("*"))
def predict(image: PIL.Image.Image, score_threshold: float, iou_threshold: float) -> np.ndarray:
    """Detect anime faces in *image* and return a copy with boxes drawn.

    Args:
        image: Input image.
        score_threshold: Confidence threshold passed to NMS.
        iou_threshold: IoU threshold passed to NMS.

    Returns:
        Array copy of the input with green rectangles around detections.
    """
    original = np.asarray(image)
    # Letterbox to the 640-px input size the network expects.
    resized = letterbox(original, new_shape=640)[0]
    batch = torch.from_numpy(resized.transpose(2, 0, 1)).float() / 255
    batch = batch.to(device).unsqueeze(0)
    if device.type != "cpu":
        batch = batch.half()  # model weights are fp16 on GPU

    raw_preds = non_max_suppression(model(batch)[0], score_threshold, iou_threshold)

    kept = []
    for det in raw_preds:
        if det is not None and len(det) > 0:
            # Map boxes from letterboxed coordinates back onto the original image.
            det[:, :4] = scale_coords(batch.shape[2:], det[:, :4], original.shape).round()
            # Each row is (x0, y0, x1, y1, confidence, class).
            kept.append(det.cpu().numpy())
    detections = np.concatenate(kept) if kept else np.empty(shape=(0, 6))

    annotated = original.copy()
    for row in detections:
        x0, y0, x1, y1 = row[:4].astype(int)
        cv2.rectangle(annotated, (x0, y0), (x1, y1), (0, 255, 0), 3)
    return annotated
# Example rows: (image path, score threshold, IoU threshold).
image_paths = load_sample_image_paths()
examples = [[p.as_posix(), 0.4, 0.5] for p in image_paths]

with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="pil")
            score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.05, value=0.4)
            iou_threshold = gr.Slider(label="IoU Threshold", minimum=0, maximum=1, step=0.05, value=0.5)
            run_button = gr.Button()
        with gr.Column():
            result = gr.Image(label="Output")

    inputs = [image, score_threshold, iou_threshold]
    # Clicking an example row and clicking Run both invoke the same predictor.
    gr.Examples(examples=examples, inputs=inputs, outputs=result, fn=predict)
    run_button.click(fn=predict, inputs=inputs, outputs=result)

if __name__ == "__main__":
    demo.launch(css_paths="style.css")