"""VisionExtract inference CLI: UNet-based subject isolation for single images or folders."""
| import sys | |
| import os | |
| import time | |
| # Redirect PyTorch and XDG caches to E: drive to prevent C: drive filling up | |
| os.environ['TORCH_HOME'] = r'E:\torch_cache' | |
| os.environ['XDG_CACHE_HOME'] = r'E:\torch_cache' | |
| # Ensure src directory is in path for local imports | |
| sys.path.append(os.path.dirname(__file__)) | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import logging | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from model import UNet | |
| from dataset import get_val_transforms | |
# Logging setup: every message goes both to logs/inference.log and to the console.
os.makedirs("logs", exist_ok=True)
_log_handlers = [
    logging.FileHandler("logs/inference.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
class VisionExtractPipeline:
    """End-to-end subject-isolation pipeline built around a UNet segmenter.

    Loads a trained UNet checkpoint (when available), runs inference at a
    square resolution, undoes the letterbox padding applied by the validation
    transforms, and uses the sigmoid probability map as a soft alpha matte to
    isolate the subject at the image's original resolution.
    """

    def __init__(self, model_path=None, device=None, image_size=256):
        """
        Args:
            model_path: Optional path to a ``.pth`` checkpoint — either a bare
                state_dict or a training dict containing ``'model_state_dict'``.
            device: torch.device to run on; auto-selects CUDA when available.
            image_size: Square inference resolution in pixels (default 256).
        """
        self.device = device if device else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        # Honor the caller-supplied resolution. This was previously hard-coded
        # to 256, which silently ignored the CLI --size flag.
        self.image_size = image_size
        self.model = UNet().to(self.device)
        if model_path and os.path.exists(model_path):
            try:
                checkpoint = torch.load(model_path, map_location=self.device)
                # Support both raw state_dicts and training checkpoints that
                # wrap the weights under 'model_state_dict'.
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
                else:
                    self.model.load_state_dict(checkpoint)
                logger.info(f"Model loaded from {model_path}")
            except Exception as e:
                # Best-effort load: keep running with random weights rather
                # than aborting, so the rest of the pipeline stays usable.
                logger.error(f"Failed to load weight: {e}. Using random weights.")
        else:
            logger.warning("No model path provided or file doesn't exist. Model is uninitialized.")
        self.model.eval()
        # Accepted input extensions (checked case-insensitively).
        self.valid_formats = (".jpg", ".png", ".jpeg")

    def full_pipeline(self, image_path, output_path=None, save=True, display=False, custom_size=None):
        """Standard segmentation pipeline: Preprocess -> Predict -> Crop -> Upscale.

        Args:
            image_path: Path to the input image (.jpg/.jpeg/.png).
            output_path: Where to write the isolated PNG; auto-generated under
                ``outputs/`` when omitted and *save* is True.
            save: Write the isolated image to disk.
            display: Show original vs. isolated side by side with matplotlib.
            custom_size: Override the inference resolution for this call only.

        Returns:
            Tuple ``(isolated, final_mask)`` — the masked RGB uint8 image and
            the float soft mask, both at the original resolution.

        Raises:
            ValueError: If the file extension is not supported.
            FileNotFoundError: If the image cannot be read.
        """
        if not image_path.lower().endswith(self.valid_formats):
            raise ValueError(f"Invalid image format: {image_path}. Supported: {self.valid_formats}")
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads BGR; the model and matplotlib expect RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h_orig, w_orig = image.shape[:2]
        # Determine image size for inference
        target_size = custom_size if custom_size else self.image_size
        transforms = get_val_transforms(image_size=target_size)
        # 1. Preprocess
        augmented = transforms(image=image)
        input_tensor = augmented['image'].unsqueeze(0).to(self.device)
        # 2. Prediction
        with torch.no_grad():
            output = self.model(input_tensor)
            # Use raw probabilities for smooth alpha-matting
            prediction = torch.sigmoid(output).squeeze().cpu().numpy()
        # 3. Handle Padding Adjustment
        # Undo the centered letterbox padding. NOTE(review): this assumes the
        # val transforms are LongestMaxSize + centered PadIfNeeded (as in
        # dataset.py) and mirrors their rounding — confirm if dataset.py changes.
        scale = target_size / max(h_orig, w_orig)
        new_h, new_w = int(h_orig * scale), int(w_orig * scale)
        pad_top = (target_size - new_h) // 2
        pad_left = (target_size - new_w) // 2
        # Extract correctly aligned valid region
        valid_mask = prediction[pad_top:pad_top+new_h, pad_left:pad_left+new_w]
        # 4. Final Upscale & Matting
        # Resize mask back to original resolution for smooth, high-fidelity isolation
        final_mask = cv2.resize(valid_mask, (w_orig, h_orig), interpolation=cv2.INTER_LINEAR)
        # Soft matte: multiply each pixel by its foreground probability.
        isolated = (image * final_mask[:, :, None]).astype(np.uint8)
        # Save Output
        if save:
            if not output_path:
                timestamp = str(int(time.time()))
                output_folder = "outputs"
                os.makedirs(output_folder, exist_ok=True)
                output_path = os.path.join(output_folder, f"isolated_{timestamp}.png")
            else:
                os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else ".", exist_ok=True)
            output_image = Image.fromarray(isolated)
            output_image.save(output_path)
        if display:
            plt.figure(figsize=(10, 5))
            plt.subplot(1, 2, 1); plt.imshow(image); plt.title("Original")
            plt.subplot(1, 2, 2); plt.imshow(isolated); plt.title("Isolated (Standard)")
            plt.show()
        return isolated, final_mask

    def batch_inference(self, folder_path, output_dir="outputs"):
        """Process all images in a folder using the standard pipeline.

        Per-image failures are logged and skipped so one bad file does not
        abort the batch.
        """
        if not os.path.exists(folder_path):
            logger.error(f"Folder {folder_path} does not exist.")
            return
        images = [f for f in os.listdir(folder_path) if f.lower().endswith(self.valid_formats)]
        logger.info(f"Found {len(images)} images. Processing batch...")
        os.makedirs(output_dir, exist_ok=True)
        start_time = time.time()
        for img_name in images:
            img_path = os.path.join(folder_path, img_name)
            # Use splitext instead of split('.')[0]: names like "a.b.jpg" and
            # "a.c.jpg" previously both collapsed to "isolated_a.png",
            # silently overwriting each other.
            base_name, _ = os.path.splitext(img_name)
            out_path = os.path.join(output_dir, f"isolated_{base_name}.png")
            try:
                self.full_pipeline(img_path, output_path=out_path, save=True, display=False)
            except Exception as e:
                logger.error(f"Error processing {img_name}: {e}")
        logger.info(f"Batch completed in: {time.time() - start_time:.2f}s.")
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser(description="VisionExtract Subject Isolation CLI") | |
| parser.add_argument("--image", type=str, help="Path to a single image") | |
| parser.add_argument("--dir", type=str, help="Path to a directory for batch processing") | |
| parser.add_argument("--output", type=str, help="Specify output path for single image") | |
| parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory for batch") | |
| parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint (.pth)") | |
| parser.add_argument("--size", type=int, default=256, help="Inference resolution") | |
| parser.add_argument("--display", action="store_true", help="Display results using Matplotlib") | |
| args = parser.parse_args() | |
| # Model weight detection | |
| model_path = args.checkpoint | |
| if not model_path: | |
| checkpoint_dir = "checkpoints" | |
| if os.path.exists(checkpoint_dir): | |
| checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")] | |
| if checkpoints: | |
| if "best_model.pth" in checkpoints: | |
| model_path = os.path.join(checkpoint_dir, "best_model.pth") | |
| else: | |
| # Sort to get latest epoch | |
| checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]) if 'epoch' in x else 0) | |
| model_path = os.path.join(checkpoint_dir, checkpoints[-1]) | |
| pipeline = VisionExtractPipeline(model_path=model_path, image_size=args.size) | |
| if args.image: | |
| pipeline.full_pipeline(args.image, output_path=args.output, save=True, display=args.display) | |
| elif args.dir: | |
| pipeline.batch_inference(args.dir, output_dir=args.output_dir) | |
| else: | |
| print("Usage: python src/inference.py --image <path> [--output <path>] OR --dir <path>") | |