Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from setup import setup | |
| import torch | |
| import gc | |
| from PIL import Image | |
| from transformers import AutoModel, AutoImageProcessor | |
| from anime2sketch.model import Anime2Sketch | |
| import spaces | |
# Project-local environment preparation (from setup.py). It runs before the
# model objects further down are constructed.
# NOTE(review): presumably this provisions ./models/netG.pth, which is loaded
# below; confirm against setup.py.
setup()
print("Setup finished")
# Hugging Face Hub repository for the transformers port of MangaLineExtraction.
# MangaLineExtractor loads both its model and its image processor from it with
# trust_remote_code=True.
MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf"
class MangaLineExtractor:
    """Callable wrapper around the transformers port of MangaLineExtraction.

    The model and its image processor are class attributes. They are loaded
    once, when this class body executes at import time, and are shared by
    every instance.
    """

    # Loaded with trust_remote_code=True because the repo ships custom
    # model/processor code.
    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)

    def __call__(self, image: Image.Image) -> Image.Image:
        """Extract line art from ``image``.

        Args:
            image: Input illustration, passed unchanged to the image processor.

        Returns:
            A single-channel ("L" mode) PIL image built from the first item of
            the model's ``pixel_values`` output.
        """
        inputs = self.processor(image, return_tensors="pt")
        # Inference only. Without no_grad the output can carry autograd
        # history, and Tensor.numpy() refuses tensors that require grad.
        with torch.no_grad():
            outputs = self.model(inputs.pixel_values)
        # Clamp before the uint8 cast so out-of-range values saturate at
        # 0/255 instead of wrapping around. .cpu() keeps .numpy() valid even
        # if the model is ever moved to a GPU.
        line_array = outputs.pixel_values[0].cpu().clamp(0, 255).numpy().astype("uint8")
        return Image.fromarray(line_array, mode="L")
# Module-level model instances, created once at import time and used by
# extract() and convert_to_sketch().
# Instantiating MangaLineExtractor is cheap: its model and processor are class
# attributes that were already loaded when the class body ran.
mle_model = MangaLineExtractor()
# Anime2Sketch model built from the local checkpoint ./models/netG.pth.
# NOTE(review): the second argument looks like the device ("cpu"); confirm
# against anime2sketch.model.Anime2Sketch.
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")
def flush():
    """Free memory: run the garbage collector, then release cached CUDA memory.

    NOTE(review): not referenced anywhere in the visible code.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
def extract(image):
    """Return the MangaLineExtraction line art for ``image`` (via ``mle_model``)."""
    return mle_model(image)
def convert_to_sketch(image):
    """Return whatever ``a2s_model.predict`` produces for ``image``.

    ``start`` calls this with a PIL image already converted to RGB.
    """
    return a2s_model.predict(image)
def start(image):
    """Run both line-extraction pipelines on one input image.

    Args:
        image: Value of the input ``gr.Image`` component. This is an array
            accepted by ``PIL.Image.fromarray``, or ``None`` when no image is
            loaded.

    Returns:
        ``[manga_line_result, anime2sketch_result]`` in the order of the two
        output components, or ``[None, None]`` when there is no input, which
        leaves both outputs empty.
    """
    # Clicking "Start" with an empty input yields None, and
    # Image.fromarray(None) would raise.
    if image is None:
        return [None, None]
    # Convert once and feed the same RGB PIL image to both pipelines. This
    # matches MangaLineExtractor.__call__'s Image.Image annotation and avoids
    # passing e.g. RGBA arrays straight to the processor.
    rgb_image = Image.fromarray(image).convert("RGB")
    return [extract(rgb_image), convert_to_sketch(rgb_image)]
def clear():
    """Return a fresh two-item list of ``None`` that empties both output images."""
    return [None] * 2
def ui():
    """Assemble the Gradio app for this demo.

    The page has a header, then an input column (image plus Start/Clear
    buttons) beside an output column with one image per pipeline, then a set
    of example inputs. Start runs ``start`` and Clear runs ``clear``. Both
    write to the same two output images.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as blocks:
        gr.Markdown(
            """
            # Anime to Sketch
            Unofficial demo for converting illustrations into sketches.
            Original repos:
            - [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
            - [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)
            Using with 🤗 transformers:
            - [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)
            """
        )

        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)
                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)

        # Order matches the two-item lists returned by start() and clear().
        result_imgs = [extract_output_img, to_sketch_output_img]

        example_inputs = [
            ["./examples/0.jpg"],
            ["./examples/1.jpg"],
            ["./examples/2.jpg"],
        ]
        gr.Examples(
            examples=example_inputs,
            inputs=[input_img],
            outputs=result_imgs,
            fn=start,
            label="Examples",
        )
        gr.Markdown("Images are from nijijourney.")

        extract_btn.click(fn=start, inputs=[input_img], outputs=result_imgs)
        clear_btn.click(fn=clear, inputs=[], outputs=result_imgs)

    return blocks
if __name__ == "__main__":
    # Build the interface, then start the Gradio server with default settings.
    app = ui()
    app.launch()