import gradio as gr from DISTS_pytorch import DISTS from torchvision.io import read_image import torch import torchvision.transforms.v2 as transforms import spaces from metrics.DeepDC import DeepDC from metrics.DeepWSD import DeepWSD from metrics.ADISTS import ADISTS from dreamsim import dreamsim # pyiqa requires older version of packages, causing dependency issues during install. Therefore, we install it here. # Specifically, it requires transformers=4.37.2. try: import pyiqa except ImportError: print("pyiqa not found. Installing...") import subprocess import sys subprocess.check_call([sys.executable, "-m", "pip", "install", "pyiqa==0.1.14.1", "--no-deps"]) import pyiqa # Download models once at startup _, _ = dreamsim(pretrained=True, device="cpu") @spaces.GPU(duration=15) class Evaluator: def __init__(self, device): self.device = device self.transform = transforms.ToDtype(dtype=torch.float32, scale=True) self.metrics = self._init_metrics() def _init_metrics(self): return { "↓ MSE": torch.nn.functional.mse_loss, "↓ L1": torch.nn.functional.l1_loss, "↓ DISTS": DISTS().to(self.device), "↓ ADISTS": ADISTS().to(self.device), "↓ DeepDC": DeepDC().to(self.device), "↓ DeepWSD": DeepWSD().to(self.device), "↓ LPIPS": pyiqa.create_metric("lpips", device=self.device), "↓ DreamSim": dreamsim(pretrained=True, device=self.device)[0], "↑ PSNR": pyiqa.create_metric("psnr", device=self.device), "↑ SSIM": pyiqa.create_metric("ssim", device=self.device), "↑ MS-SSIM": pyiqa.create_metric("ms_ssim", device=self.device), "↑ CW-SSIM": pyiqa.create_metric("cw_ssim", device=self.device), "↑ FSIM": pyiqa.create_metric("fsim", device=self.device), } @torch.no_grad() def evaluate(self, img_fname1, img_fname2): img1 = self.transform(read_image(img_fname1)).unsqueeze(0).to(self.device) img2 = self.transform(read_image(img_fname2)).unsqueeze(0).to(self.device) # check images are the same size if img1.shape != img2.shape: return "Input images must have the same dimensions!" return "\n".join( f"{name:<10}: {float(metric(img1, img2).item()):3,.5f}" for name, metric in self.metrics.items() ) @spaces.GPU(duration=5) def get_evaluator(): """Returns a singleton Evaluator instance per worker/session.""" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if not hasattr(get_evaluator, "evaluator"): get_evaluator.evaluator = Evaluator(device) return get_evaluator.evaluator @spaces.GPU(duration=20) def compute_similarity(img1_path, img2_path): """Main function for Gradio interface.""" if not img1_path or not img2_path: return "Please upload both images!" return get_evaluator().evaluate(img1_path, img2_path) def create_interface(): examples = [ ["examples/01_1.jpg", "examples/01_1.jpg"], # Add an extra example for identical images ["examples/01_1.jpg", "examples/noise.jpg"], ["examples/00_1.jpg", "examples/00_1_rotated.png"], *[[f"examples/{i:02d}_1.jpg", f"examples/{i:02d}_2.jpg"] for i in range(1, 10)], ] # Custom CSS css = """ .center-header { display: flex; align-items: center; justify-content: center; margin: 0 0 10px 0; } .monospace-text { font-family: 'Courier New', Courier, monospace; } .metrics-table { width: 100%; border-collapse: collapse; } .metrics-table td { padding: 10px; vertical-align: top; } """ # Add UI elements pyiqa_url = "https://github.com/chaofengc/IQA-PyTorch" with gr.Blocks(title="FR-IQA", css=css) as demo: gr.Markdown(f"""
| Metric | Score Range | Lower is better? | Source |
|---|---|---|---|
| MSE | [0, ∞) | Yes | torch |
| L1 | [0, ∞) | Yes | torch |
| DISTS | [0, 1] | Yes | official |
| ADISTS | ~[0, 1] | Yes | official |
| DeepDC | [0, 1] | Yes | official |
| DeepWSD | [0, ∞) | Yes | official |
| LPIPS | [0, 1] | Yes | pyiqa |
| DreamSim | [0, 1] | Yes | official |
| PSNR | [0, ∞) | No | pyiqa |
| SSIM | [0, 1] | No | pyiqa |
| MS-SSIM | [0, 1] | No | pyiqa |
| CW-SSIM | [0, 1] | No | pyiqa |
| FSIM | [0, 1] | No | pyiqa |