| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from itertools import islice |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
| from torchvision.transforms import transforms |
|
|
| from models import RRDBNet |
|
|
REPO_ID = "kadirnar/BSRGANx2"

# Download the pretrained BSRGANx2 weights from the Hugging Face Hub
# (cached locally after the first call).
pretrain_model_path = hf_hub_download(repo_id=REPO_ID, filename="BSRGANx2.pth")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# RRDBNet configured for 2x super-resolution (sf=2), matching the checkpoint.
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)
# map_location lets a GPU-saved checkpoint load on CPU-only machines;
# weights_only=True avoids executing arbitrary pickled code from the file.
model.load_state_dict(
    torch.load(pretrain_model_path, map_location=device, weights_only=True),
    strict=True,
)
model.eval()

# Freeze all parameters — the model is used for inference only.
model.requires_grad_(False)

model = model.to(device)

# Converts an HxWxC uint8 image into a float CxHxW tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])
|
|
|
|
def predict(image):
    """Run 2x super-resolution on *image* and return (original, enhanced).

    Parameters:
        image (np.ndarray): Input image as an HxWxC uint8 array, as
            delivered by the ``gr.Image`` component with ``type="numpy"``.

    Returns:
        tuple[np.ndarray, np.ndarray]: The untouched input and the enhanced
        image (HxWxC uint8), a pair suitable for ``gr.ImageSlider``.
    """
    # no_grad skips building an autograd graph — pure inference, less memory.
    with torch.no_grad():
        tensor = transform(image).unsqueeze(0).to(device)
        output = model(tensor)

    # Back to a CPU array, clamped to the valid [0, 1] range.
    result = output.squeeze().float().clamp(0, 1).cpu().numpy()

    # CxHxW -> HxWxC for display (grayscale output stays 2-D as-is).
    if result.ndim == 3:
        result = np.transpose(result, (1, 2, 0))

    return image, (result * 255.0).round().astype(np.uint8)
|
|
|
|
with gr.Blocks(title="BSRGAN") as app:
    navbar = gr.Navbar(visible=True, main_page_name="Workspace")
    gr.Markdown("## BSRGANx2")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                source_image = gr.Image(type="numpy", label="Image")
            image_btn = gr.Button("Enhance image")
        with gr.Column(scale=1):
            with gr.Row():
                # predict returns numpy arrays, so the slider must be
                # type="numpy" — "filepath" would mismatch the handler output.
                output_image = gr.ImageSlider(label="Enhanced image", type="numpy")

    image_btn.click(fn=predict, inputs=[source_image], outputs=output_image)

with app.route("Readme", "/readme"):
    # Skip the first 15 lines (front matter / header) of the README.
    with open("README.md") as f:
        for line in islice(f, 15, None):
            gr.Markdown(line.strip())

# queue() must be configured before launch(); with debug=True launch() blocks,
# so a queue() call placed after it (as originally written) never took effect.
app.queue()
app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
|
|