"""Flask service that runs images through a chain of restoration models.

A request names an input image and an ordered list of models. The image is
resized, passed through each loaded model in turn, and the final result is
scored with ``IQAReward``.
"""

import os
import time
from pathlib import Path

from PIL import Image
from flask import Flask, request, jsonify

# Import model loaders and predictors
from .RIDCP.inference import load_ridcp_model, ridcp_predict
from .SCUNet.inference import load_scu_model, scu_predict
from .Retinexformer.inference import load_retinexformer_model, retinexformer_predict
from .img2img_turbo.inference import load_turbo_model, turbo_predict
from .ESRGAN.inference import load_esrgan_model, esrgan_predict
from .IDT.inference import load_idt_model, idt_predict
from .iqa_reward import IQAReward

# Configure environment variables
# NOTE(review): these are set after the model packages are imported; if any
# of those packages read the variables at import time the values below come
# too late -- confirm and move above the imports if so.
os.environ["BASICSR_JIT"] = "True"
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# Initialize Flask application
app = Flask(__name__)

# Global variables
# NOTE(review): neither name is used elsewhere in this file (ModelTester keeps
# its own registry and its own IQAReward instance). They are kept because
# other modules may import them.
models = {}
iqa = IQAReward()


class ModelTester:
    """
    Model testing service for image restoration models.

    This class manages model loading, image processing, and quality assessment.
    """

    def __init__(self, output_base_dir="datasets/tmp_result"):
        """
        Initialize the model tester.

        Args:
            output_base_dir (str): Base directory for storing results.
        """
        self.output_base_dir = output_base_dir
        # Loaded model objects, keyed by model name (filled by load_models).
        self.models = {}
        self.iqa = IQAReward()
        # Registry: model name -> (zero-argument loader, predict function).
        self.model_loaders = {
            'scunet': (load_scu_model, scu_predict),
            'retinexformer_lolv2': (lambda: load_retinexformer_model('LOLV2'), retinexformer_predict),
            'retinexformer_fivek': (lambda: load_retinexformer_model('FiveK'), retinexformer_predict),
            'turbo_night': (lambda: load_turbo_model('night'), turbo_predict),
            'turbo_rain': (lambda: load_turbo_model('rain'), turbo_predict),
            'turbo_snow': (lambda: load_turbo_model('snow'), turbo_predict),
            'real_esrgan': (load_esrgan_model, esrgan_predict),
            'ridcp': (load_ridcp_model, ridcp_predict),
            'idt': (load_idt_model, idt_predict),
        }

    def load_models(self, model_names):
        """
        Load specified models into memory.

        Any previously loaded models are discarded first. Names that are not
        in the loader registry are reported and skipped.

        Args:
            model_names (list): List of model names to load.
        """
        print(f"Loading models: {', '.join(model_names)}")

        self.models = {}
        for model_name in model_names:
            if model_name in self.model_loaders:
                loader_fn = self.model_loaders[model_name][0]
                self.models[model_name] = loader_fn()
                print(f"Loaded {model_name}")
            else:
                print(f"Unknown model: {model_name}")

        print(f"Finished loading {len(self.models)} models")

    def resize_image(self, img_path, output_dir, target_size=(256, 256)):
        """
        Resize input image to a standard size.

        The result is written as ``<input stem>.png`` inside ``output_dir``.

        Args:
            img_path (str): Path to the input image.
            output_dir (str): Directory to save the resized image.
            target_size (tuple): Target resolution (width, height).

        Returns:
            str: Path to the resized image.
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        with Image.open(img_path) as img:
            # Ensure consistent color mode
            img = img.convert('RGB')

            # Use high-quality resampling
            img = img.resize(target_size, Image.LANCZOS)

            # Generate output filename
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            save_path = os.path.join(output_dir, f"{img_name}.png")

            # Save the resized image
            img.save(save_path, format='PNG')

        return save_path

    def process_image_with_models(self, model_list, img_path, output_dir):
        """
        Process an image with a sequence of models.

        Each model receives the output path of the previous step. Models that
        are not currently loaded are reported and skipped, so if none of the
        requested models is loaded the resized input is returned unchanged.

        Args:
            model_list (list): List of model names to apply in sequence.
            img_path (str): Path to the input image.
            output_dir (str): Directory to save the processed images.

        Returns:
            str: Path to the final processed image.
        """
        # Resize input image
        img_path = self.resize_image(img_path, output_dir)

        # Apply each model in sequence
        for model_name in model_list:
            if model_name not in self.models:
                print(f"Model {model_name} not loaded, skipping")
                continue

            # Get the predict function for this model
            _, predict_fn = self.model_loaders[model_name]

            # Process the image with the current model
            img_path = predict_fn(self.models[model_name], img_path, output_dir)
            print(f"Applied {model_name}, saved result to {img_path}")

        return img_path

    def create_output_dir(self):
        """
        Create a unique output directory based on current timestamp.

        The directory is named after the current Unix time in whole seconds.
        If that name is already taken (for example by another request handled
        in the same second), ``_1``, ``_2``, ... is appended until a fresh
        directory can be created, so results never share a directory.

        Returns:
            str: Path to the created output directory.
        """
        timestamp = int(time.time())
        base_dir = os.path.join(self.output_base_dir, f"{timestamp}")

        candidate = base_dir
        attempt = 0
        while True:
            try:
                # exist_ok is left False on purpose: creating the leaf is what
                # claims the name, so a concurrent caller cannot get the same
                # directory. Missing parent directories are still created.
                os.makedirs(candidate)
                return candidate
            except FileExistsError:
                attempt += 1
                candidate = f"{base_dir}_{attempt}"

    def process_request(self, img_path, model_list):
        """
        Process an image with the specified models and evaluate the result.

        Args:
            img_path (str): Path to the input image.
            model_list (list): List of model names to apply.

        Returns:
            dict: Dictionary with output path and quality score.

        Raises:
            FileNotFoundError: If the input path is not an existing file.
        """
        # Verify the image path. isfile (rather than exists) also rejects
        # directories, which would otherwise fail later inside Image.open.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")

        # Create a unique output directory
        output_dir = self.create_output_dir()

        # Process the image
        final_output = self.process_image_with_models(model_list, img_path, output_dir)

        # Evaluate the result
        score = self.iqa.get_iqa_score(final_output)

        return {
            "output_path": final_output,
            "score": score
        }


# Initialize the model tester (created by start_server)
model_tester = None


@app.route('/process_image', methods=['POST'])
def process_image():
    """
    API endpoint for processing an image with specified models.

    Expects a JSON payload with:
    - img_path: Path to the input image
    - models: List of model names to apply

    Returns:
    - JSON with output_path and score on success
    - 400 for a missing or malformed payload, 404 if the image is not found,
      503 if the model service has not been initialized, 500 for any other
      processing failure
    """
    # start_server() has not run (e.g. the app is served by another runner).
    if model_tester is None:
        return jsonify({"error": "Model service is not initialized"}), 503

    # Parse request data. silent=True yields None instead of raising on a
    # missing or invalid JSON body, so a JSON error can be returned below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    img_path = data.get('img_path')
    models_to_use = data.get('models', [])

    # Validate input
    if not img_path:
        return jsonify({"error": "Missing image path"}), 400

    if not models_to_use:
        return jsonify({"error": "No models specified"}), 400

    # A bare string would be iterated character by character and every
    # "model" silently skipped, so require an actual list.
    if not isinstance(models_to_use, (list, tuple)):
        return jsonify({"error": "'models' must be a list of model names"}), 400

    try:
        # Process the image
        result = model_tester.process_request(img_path, models_to_use)
        return jsonify(result)

    except FileNotFoundError as e:
        return jsonify({"error": str(e)}), 404

    except Exception as e:
        # Top-level boundary: report any other failure to the client as a 500.
        return jsonify({"error": f"Processing failed: {str(e)}"}), 500


def start_server(host='0.0.0.0', port=5010, model_names=None):
    """
    Start the API server with specified models.

    Args:
        host (str): Host address to bind the server.
        port (int): Port to listen on.
        model_names (list): List of model names to load. If None, loads a default set.
    """
    global model_tester

    # Initialize the model tester
    model_tester = ModelTester()

    # Define default models if none specified
    if model_names is None:
        model_names = [
            'scunet',
            'real_esrgan',
            'ridcp',
            'idt',
            'turbo_rain',
            'turbo_night',
            'retinexformer_lolv2',
            'retinexformer_fivek',
        ]

    # Load the models
    model_tester.load_models(model_names)

    # Start the Flask application
    print(f"Starting API server at http://{host}:{port}")
    app.run(host=host, port=port)


if __name__ == '__main__':
    # Start the server with default settings
    start_server()