Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from DISTS_pytorch import DISTS | |
| from torchvision.io import read_image | |
| import torch | |
| import torchvision.transforms.v2 as transforms | |
| import spaces | |
| from metrics.DeepDC import DeepDC | |
| from metrics.DeepWSD import DeepWSD | |
| from metrics.ADISTS import ADISTS | |
| from dreamsim import dreamsim | |
# pyiqa pins older versions of its dependencies (specifically
# transformers==4.37.2), which causes dependency conflicts when it is
# installed alongside the other requirements. It is therefore installed here,
# at runtime and without its dependencies, only if it is missing.
try:
    import pyiqa
except ImportError:
    print("pyiqa not found. Installing...")
    import importlib
    import subprocess
    import sys

    subprocess.check_call([sys.executable, "-m", "pip", "install", "pyiqa==0.1.14.1", "--no-deps"])
    # The import system caches directory listings of sys.path entries; a
    # package installed after interpreter start-up may not be noticed until
    # those caches are cleared.
    importlib.invalidate_caches()
    import pyiqa
# Instantiate DreamSim once on CPU at import time so its pretrained weights
# are fetched up front; both returned objects (model, preprocess) are
# discarded. Presumably the later dreamsim(...) call inside Evaluator then
# reuses the locally cached weights — confirm. Only DreamSim is prefetched
# here; the other metrics load their weights when Evaluator is built.
(_, _) = dreamsim(device="cpu", pretrained=True)
class Evaluator:
    """Compute a fixed set of full-reference image-similarity metrics.

    Every metric is created once in ``__init__`` and placed on ``device``;
    ``evaluate`` then scores a pair of image files with all of them and
    returns a plain-text report.
    """

    def __init__(self, device):
        """
        Args:
            device: torch device on which the metrics and input tensors live.
        """
        self.device = device
        # uint8 image -> float32, rescaled to [0, 1] (scale=True).
        self.transform = transforms.ToDtype(dtype=torch.float32, scale=True)
        self.metrics = self._init_metrics()

    def _init_metrics(self):
        """Instantiate every metric, keyed by its display label.

        Labels prefixed with ↓ are metrics where lower is better (distances);
        labels prefixed with ↑ are metrics where higher is better
        (similarities). Each value is called as ``metric(img1, img2)`` and
        must return a tensor holding a single score.
        """
        return {
            "↓ MSE": torch.nn.functional.mse_loss,
            "↓ L1": torch.nn.functional.l1_loss,
            "↓ DISTS": DISTS().to(self.device),
            "↓ ADISTS": ADISTS().to(self.device),
            "↓ DeepDC": DeepDC().to(self.device),
            "↓ DeepWSD": DeepWSD().to(self.device),
            "↓ LPIPS": pyiqa.create_metric("lpips", device=self.device),
            "↓ DreamSim": dreamsim(pretrained=True, device=self.device)[0],
            "↑ PSNR": pyiqa.create_metric("psnr", device=self.device),
            "↑ SSIM": pyiqa.create_metric("ssim", device=self.device),
            "↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device),
            "↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device),
            "↑ FSIM": pyiqa.create_metric("fsim", device=self.device),
        }

    def _load(self, img_fname):
        """Read an image file as a batched (1, 3, H, W) float tensor on self.device."""
        # Local import keeps this block self-contained; ImageReadMode lives
        # next to read_image in torchvision.io.
        from torchvision.io import ImageReadMode

        # Force 3-channel decoding. In the default (UNCHANGED) mode a
        # grayscale or RGBA file (e.g. a PNG with alpha) yields 1 or 4
        # channels, which the RGB-oriented metrics (e.g. the VGG-based DISTS)
        # presumably cannot consume, and which also makes the shape check in
        # evaluate() reject images that actually have the same size.
        img = read_image(img_fname, mode=ImageReadMode.RGB)
        return self.transform(img).unsqueeze(0).to(self.device)

    def evaluate(self, img_fname1, img_fname2):
        """Score two image files with every metric.

        Args:
            img_fname1: path to the first image.
            img_fname2: path to the second image.

        Returns:
            str: one ``"<label>: <score>"`` line per metric, or an error
            message when the images differ in size.
        """
        img1 = self._load(img_fname1)
        img2 = self._load(img_fname2)
        # All metrics compare pixel-aligned images, so sizes must match.
        if img1.shape != img2.shape:
            return "Input images must have the same dimensions!"
        # Inference only: skip building autograd graphs to save memory.
        with torch.no_grad():
            scores = {name: float(metric(img1, img2).item()) for name, metric in self.metrics.items()}
        return "\n".join(f"{name:<10}: {score:3,.5f}" for name, score in scores.items())
def get_evaluator():
    """Return the cached Evaluator, building it on the first call.

    The instance is stored as an attribute of this function, so it is shared
    by every caller in the same Python process. The Evaluator targets CUDA
    when it is available and the CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    missing = object()
    instance = getattr(get_evaluator, "evaluator", missing)
    if instance is missing:
        instance = Evaluator(device)
        get_evaluator.evaluator = instance
    return instance
def compute_similarity(img1_path, img2_path):
    """Gradio callback: compare two uploaded images.

    Args:
        img1_path: filepath of the first image; falsy when nothing was uploaded.
        img2_path: filepath of the second image; falsy when nothing was uploaded.

    Returns:
        str: the metric report from the shared evaluator, or a prompt asking
        for both images when either input is missing.
    """
    # NOTE(review): `spaces` is imported at the top of the file, but nothing
    # is decorated with `@spaces.GPU`. If this app runs on ZeroGPU hardware,
    # this callback presumably needs that decorator — confirm.
    both_uploaded = bool(img1_path) and bool(img2_path)
    if both_uploaded:
        return get_evaluator().evaluate(img1_path, img2_path)
    return "Please upload both images!"
def create_interface():
    """Build the Gradio Blocks UI for the FR-IQA demo.

    Layout:
        - two image inputs and a text box for the metric report;
        - a "Compute Metrics" button wired to ``compute_similarity``;
        - lazily cached example pairs;
        - an acknowledgements table describing each metric.

    Returns:
        gr.Blocks: the assembled demo; the caller is responsible for
        calling ``launch()``.
    """
    examples = [
        # Identical pair first, as a reference point for the other scores.
        ["examples/01_1.jpg", "examples/01_1.jpg"],
        ["examples/01_1.jpg", "examples/noise.jpg"],
        ["examples/00_1.jpg", "examples/00_1_rotated.png"],
        *[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)],
    ]

    css = """
.center-header {
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 0 0 10px 0;
}
.monospace-text {
    font-family: 'Courier New', Courier, monospace;
}
.metrics-table {
    width: 100%;
    border-collapse: collapse;
}
.metrics-table td {
    padding: 10px;
    vertical-align: top;
}
"""

    pyiqa_url = "https://github.com/chaofengc/IQA-PyTorch"
    torch_docs_url = "https://docs.pytorch.org/docs/stable/generated"
    # One row per metric: (name, rough score range, lower is better?,
    # source label, source URL). Order matches Evaluator's metric order.
    metric_info = [
        ("MSE", "[0, ∞)", "Yes", "torch", f"{torch_docs_url}/torch.nn.MSELoss.html"),
        ("L1", "[0, ∞)", "Yes", "torch", f"{torch_docs_url}/torch.nn.L1Loss.html"),
        ("DISTS", "[0, 1]", "Yes", "official", "https://github.com/dingkeyan93/DISTS"),
        ("ADISTS", "~[0, 1]", "Yes", "official", "https://github.com/dingkeyan93/A-DISTS"),
        ("DeepDC", "[0, 1]", "Yes", "official", "https://github.com/h4nwei/DeepDC"),
        ("DeepWSD", "[0, ∞)", "Yes", "official", "https://github.com/Buka-Xing/DeepWSD"),
        ("LPIPS", "[0, 1]", "Yes", "pyiqa", pyiqa_url),
        ("DreamSim", "[0, 1]", "Yes", "official", "https://github.com/ssundaram21/dreamsim"),
        ("PSNR", "[0, ∞)", "No", "pyiqa", pyiqa_url),
        ("SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("MS-SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("CW-SSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
        ("FSIM", "[0, 1]", "No", "pyiqa", pyiqa_url),
    ]
    table_rows = "\n".join(
        "<tr>"
        f"<td>{name}</td><td>{score_range}</td><td>{lower_is_better}</td>"
        f'<td><a href="{url}">{source}</a></td>'
        "</tr>"
        for name, score_range, lower_is_better, source, url in metric_info
    )

    # Blank lines after the HTML <div> lines end the raw-HTML block so the
    # following lines are rendered as Markdown.
    header_md = """
<div class='center-header'><h1>Full-Reference Image Quality Assessment</h1></div>

Upload two images to compute various similarity metrics.<br>
**Note**: Images must have identical dimensions. Code runs much faster locally: due to the ZeroGPU setup, metrics are re-initialized on every run.
"""
    acknowledgements_md = f"""
<div class='center-header'><h3>Acknowledgements</h3></div>

- Example images are from the [TryOffDiff](https://rizavelioglu.github.io/tryoffdiff) paper and are sampled from the VITON-HD dataset.
- Metrics (*score ranges are only rough estimates; actual ranges may vary*):

<table class="metrics-table">
<tr>
<th>Metric</th>
<th>Score Range</th>
<th>Lower is better?</th>
<th>Source</th>
</tr>
{table_rows}
</table>
"""

    with gr.Blocks(title="FR-IQA", css=css) as demo:
        gr.Markdown(header_md)
        with gr.Row():
            with gr.Column(scale=2):
                img_fname1 = gr.Image(type="filepath", label="Image#1", height=512, width=512)
            with gr.Column(scale=2):
                img_fname2 = gr.Image(type="filepath", label="Image#2", height=512, width=512)
            with gr.Column(scale=1):
                metrics_output = gr.Textbox(
                    label="Metrics Output",
                    lines=22,
                    elem_classes="monospace-text",
                    show_copy_button=True,
                )
        with gr.Row():
            submit_btn = gr.Button("Compute Metrics", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    examples=examples,
                    inputs=[img_fname1, img_fname2],
                    fn=compute_similarity,
                    outputs=metrics_output,
                    label="Example Pairs (all are 1024×768)",
                    cache_examples=True,
                    cache_mode="lazy",
                    examples_per_page=6,
                )
            with gr.Column(scale=2):
                gr.Markdown(acknowledgements_md)
        # Event listeners must be attached inside the Blocks context.
        submit_btn.click(
            fn=compute_similarity,
            inputs=[img_fname1, img_fname2],
            outputs=[metrics_output],
        )
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it without a public share
    # link and with Gradio's server-side rendering switched off.
    demo = create_interface()
    demo.launch(
        share=False,
        ssr_mode=False,
    )