Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| from transformers import AutoModel, AutoProcessor | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class QualityEvaluator:
    """Image quality assessment using an ensemble of scorers.

    Two learned scorers are registered under the keys ``'lar_iqa'`` and
    ``'dgiqa'``; both are currently placeholder networks built by
    :meth:`create_mock_model`. They are combined with a hand-crafted score
    computed from sharpness, contrast and brightness statistics. Every scorer
    returns a value clamped to the range 0-10.
    """

    # ImageNet channel statistics and input size shared by both learned
    # scorers' preprocessing pipelines.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)
    _INPUT_SIZE = (224, 224)

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # key -> nn.Module in eval mode, already moved to self.device
        self.models = {}
        # key -> preprocessing transform for the model with the same key
        self.processors = {}
        # Defined up front so these attributes exist even when loading succeeds
        # (fallback_mode) or fails before reaching a given loader.
        self.fallback_mode = False
        self.traditional_metrics_available = False
        self.load_models()

    def load_models(self):
        """Load every quality scorer; switch to fallback mode on unexpected errors.

        Each per-model loader catches and logs its own failures. A model that
        fails to load is therefore simply absent from ``self.models``, and the
        matching ``evaluate_with_*`` method uses :meth:`fallback_quality_score`.
        """
        try:
            # Load LAR-IQA model (primary)
            logger.info("Loading LAR-IQA model...")
            self.load_lar_iqa()
            # Load DGIQA model (secondary)
            logger.info("Loading DGIQA model...")
            self.load_dgiqa()
            # Load traditional metrics as fallback
            logger.info("Loading traditional quality metrics...")
            self.load_traditional_metrics()
        except Exception as e:
            logger.error("Error loading quality models: %s", e)
            self.use_fallback_implementation()

    @classmethod
    def _build_processor(cls):
        """Return the resize -> tensor -> ImageNet-normalize transform."""
        return transforms.Compose([
            transforms.Resize(cls._INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._IMAGENET_MEAN, std=cls._IMAGENET_STD),
        ])

    def _register_model(self, key):
        """Build the processor and model for *key* and register both together."""
        # Build both before registering anything. A failure then never leaves a
        # model without its processor, which would only surface at evaluation.
        processor = self._build_processor()
        model = self.create_mock_model()
        self.processors[key] = processor
        self.models[key] = model

    def load_lar_iqa(self):
        """Register the LAR-IQA scorer (currently a placeholder mock model)."""
        try:
            # Placeholder: in production this would load the actual LAR-IQA model.
            self._register_model('lar_iqa')
        except Exception as e:
            logger.warning("Could not load LAR-IQA: %s", e)

    def load_dgiqa(self):
        """Register the DGIQA scorer (currently a placeholder mock model)."""
        try:
            # Placeholder: in production this would load the actual DGIQA model.
            self._register_model('dgiqa')
        except Exception as e:
            logger.warning("Could not load DGIQA: %s", e)

    def load_traditional_metrics(self):
        """Mark the statistics-based metrics as available.

        Nothing is loaded yet; BRISQUE/NIQE-style metrics would be wired in here
        (e.g. via scikit-image or OpenCV).
        """
        try:
            self.traditional_metrics_available = True
        except Exception as e:
            logger.warning("Could not load traditional metrics: %s", e)
            self.traditional_metrics_available = False

    def create_mock_model(self):
        """Create a placeholder scorer whose output lies in the range 0-10.

        The network is conv(3->64) -> ReLU -> global average pool -> linear ->
        sigmoid, with the output scaled by 10. No weights are loaded, so the
        parameters keep PyTorch's default random initialisation. Its scores are
        therefore placeholders, not a real quality signal. The model expects
        3-channel input.
        """
        class MockQualityModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 3, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.AdaptiveAvgPool2d((1, 1)),
                    torch.nn.Flatten(),
                    torch.nn.Linear(64, 1),
                    torch.nn.Sigmoid(),
                )

            def forward(self, x):
                return self.backbone(x) * 10  # sigmoid (0, 1) -> (0, 10)

        model = MockQualityModel().to(self.device)
        model.eval()
        return model

    def use_fallback_implementation(self):
        """Flag that only the simple fallback assessment should be relied on."""
        logger.info("Using fallback quality assessment implementation")
        self.fallback_mode = True

    @staticmethod
    def _to_rgb(image):
        """Return *image* as a 3-band RGB image, converting only when needed.

        Both learned scorers (3-value Normalize, 3-channel conv) and the
        luminance weights assume exactly three colour bands. Without conversion,
        'L', 'P', 'RGBA', 'CMYK', etc. inputs would fail or be misread.
        """
        return image if image.mode == 'RGB' else image.convert('RGB')

    def _evaluate_with_model(self, key, label, image):
        """Score *image* with the model registered as *key*, clamped to 0-10.

        Falls back to :meth:`fallback_quality_score` when the model or its
        processor is missing, or when inference raises. *label* is the model
        name used in the error log.
        """
        if key not in self.models or key not in self.processors:
            return self.fallback_quality_score(image)
        try:
            # Convert a copy for the network only. The fallback below still
            # receives the original image, because it inspects `image.format`.
            batch = self.processors[key](self._to_rgb(image)).unsqueeze(0).to(self.device)
            with torch.no_grad():
                score = self.models[key](batch).item()
            return max(0.0, min(10.0, score))
        except Exception as e:
            logger.error("Error in %s evaluation: %s", label, e)
            return self.fallback_quality_score(image)

    def evaluate_with_lar_iqa(self, image: Image.Image) -> float:
        """Evaluate image quality using LAR-IQA (0-10)."""
        return self._evaluate_with_model('lar_iqa', 'LAR-IQA', image)

    def evaluate_with_dgiqa(self, image: Image.Image) -> float:
        """Evaluate image quality using DGIQA (0-10)."""
        return self._evaluate_with_model('dgiqa', 'DGIQA', image)

    def evaluate_traditional_metrics(self, image: Image.Image) -> float:
        """Score *image* from simple luminance statistics (0-10).

        The score combines three terms:

        - sharpness: variance of the Laplacian / 100, capped at 10 (weight 0.4)
        - contrast: luminance std / 25, capped at 10 (weight 0.3)
        - brightness: 10 minus the distance of the mean luminance from 127.5,
          scaled by 12.75 (weight 0.3)

        In production this would use BRISQUE, NIQE, etc. Returns the neutral
        5.0 for empty images or when the statistics cannot be computed
        (e.g. scipy is unavailable).
        """
        try:
            from scipy import ndimage

            # Convert first so every mode yields an (H, W, 3) array. Slicing
            # `[..., :3]` on a 2-D grayscale array would silently take pixel
            # columns instead of colour channels.
            rgb = np.asarray(self._to_rgb(image), dtype=np.float64)
            gray = rgb @ np.array([0.2989, 0.5870, 0.1140])
            if gray.size == 0:
                return 5.0

            sharpness_score = min(10.0, ndimage.laplace(gray).var() / 100.0)
            contrast_score = min(10.0, gray.std() / 25.0)
            brightness_score = 10.0 - abs(gray.mean() - 127.5) / 12.75

            quality_score = (sharpness_score * 0.4 +
                             contrast_score * 0.3 +
                             brightness_score * 0.3)
            return float(max(0.0, min(10.0, quality_score)))
        except Exception as e:
            logger.error("Error in traditional metrics: %s", e)
            return 5.0  # Default score

    def fallback_quality_score(self, image: Image.Image) -> float:
        """Heuristic score from resolution, aspect ratio and file format (0-10).

        Never raises. Any error (e.g. a zero-height image) yields the neutral 5.0.
        """
        try:
            width, height = image.size
            # One point per 100k pixels, capped at 10: 1 megapixel or more
            # earns the full resolution score.
            resolution_score = min(10.0, (width * height) / 100000.0)
            # Prefer aspect ratios between 1:2 and 2:1.
            aspect_score = 8.0 if 0.5 <= width / height <= 2.0 else 5.0
            # Prefer lossless PNG sources.
            format_score = 8.0 if image.format == 'PNG' else 6.0

            quality_score = (resolution_score * 0.5 +
                             aspect_score * 0.3 +
                             format_score * 0.2)
            return max(0.0, min(10.0, quality_score))
        except Exception:
            return 5.0  # Default neutral score

    def evaluate(self, image: Image.Image, anime_mode: bool = False) -> float:
        """
        Evaluate image quality using an ensemble of scorers.

        Args:
            image: PIL Image to evaluate.
            anime_mode: When True, weight the statistics-based score higher
                (0.3/0.3/0.4 instead of 0.4/0.4/0.2 for LAR-IQA/DGIQA/traditional),
                since it may be more reliable for stylized content.

        Returns:
            Quality score from 0-10. If ensembling itself fails, the result of
            :meth:`fallback_quality_score` is returned instead.
        """
        try:
            lar_score = self.evaluate_with_lar_iqa(image)
            dgiqa_score = self.evaluate_with_dgiqa(image)
            traditional_score = self.evaluate_traditional_metrics(image)

            if anime_mode:
                weights = (0.3, 0.3, 0.4)
            else:
                weights = (0.4, 0.4, 0.2)
            scores = (lar_score, dgiqa_score, traditional_score)
            final_score = sum(score * weight for score, weight in zip(scores, weights))

            logger.info(
                "Quality scores - LAR: %.2f, DGIQA: %.2f, Traditional: %.2f, Final: %.2f",
                lar_score, dgiqa_score, traditional_score, final_score,
            )
            return max(0.0, min(10.0, final_score))
        except Exception as e:
            logger.error("Error in quality evaluation: %s", e)
            return self.fallback_quality_score(image)