LYL1015's picture
test
eea83e8
import torch
import torchvision.transforms as tfs
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import sys
from .nafnet import NAFNetLocal
# Make the project root (two levels above this file's directory) importable
_this_dir = os.path.dirname(__file__)
project_root = os.path.abspath(os.path.join(_this_dir, "../.."))
sys.path.append(project_root)
class SnowDataset(Dataset):
    """Dataset that serves the snow images found directly inside one folder.

    Files whose names end in ``.png``, ``.jpg`` or ``.jpeg`` (compared
    case-insensitively) are served in sorted filename order. Images whose
    longer side exceeds ``max_size`` are downscaled with LANCZOS resampling,
    keeping the aspect ratio.

    Args:
        input_path (str): Directory containing the input images.
        desnow_input_size: Not used by this class.
        model_name: Not used by this class.
        max_size (int): Longest side allowed before downscaling.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, input_path, desnow_input_size=None, model_name=None, max_size=1500):
        self.input_path = input_path
        self.max_size = max_size
        self.image_list = self._list_images(input_path)
        self.transform = tfs.Compose([
            tfs.ToTensor()
        ])

    @classmethod
    def _list_images(cls, input_path):
        """Return the sorted image file names found in ``input_path``."""
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # iteration order reproducible. Lower-casing also accepts '.PNG' etc.
        return sorted(
            name for name in os.listdir(input_path)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    @staticmethod
    def _fit_size(size, max_size):
        """Return ``size`` scaled down so that no side exceeds ``max_size``.

        Sizes that already fit are returned unchanged. Each scaled side is
        kept at 1 pixel or more so very elongated images stay resizable.
        """
        longest = max(size)
        if longest <= max_size:
            return tuple(size)
        ratio = max_size / longest
        return tuple(max(1, int(dim * ratio)) for dim in size)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return ``(image_tensor, empty_tensor, empty_tensor, image_id)``.

        ``image_id`` is the file name without its extension.
        """
        image_name = self.image_list[idx]
        image_path = os.path.join(self.input_path, image_name)
        # Close the file handle as soon as the pixel data has been converted
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        # Check and resize image if too large
        new_size = self._fit_size(image.size, self.max_size)
        if new_size != image.size:
            image = image.resize(new_size, Image.LANCZOS)
        snow_input = self.transform(image)
        image_id = os.path.splitext(image_name)[0]
        return snow_input, torch.tensor([]), torch.tensor([]), image_id
def load_snowmaster_model(model_path, device='cuda'):
    """
    Load the SnowMaster model

    Args:
        model_path (str): Path to the checkpoint file, or to a directory
            that contains 'checkpoint_0318.pth'
        device (str): Device for inference, default is 'cuda'
    Returns:
        torch.nn.Module: Loaded model in evaluation mode, moved to specified device
    Raises:
        FileNotFoundError: If the resolved checkpoint file does not exist
    """
    # A directory is resolved to the default checkpoint inside it; anything
    # else is treated as the checkpoint file itself.
    if os.path.isdir(model_path):
        path = os.path.join(model_path, 'checkpoint_0318.pth')
    else:
        path = model_path
    # Fail early, before the network is built
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Checkpoint file does not exist: {path}")

    # Create model architecture
    model = NAFNetLocal(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])

    # Load checkpoint
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device)

    # Get actual model weights: use the first known wrapper key that is
    # present, otherwise treat the checkpoint itself as the state dict.
    state_dict = checkpoint
    for key in ('model_state_dict', 'state_dict', 'params'):
        if key in checkpoint:
            state_dict = checkpoint[key]
            break

    # Process weight keys, remove module prefix (module.)
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

    try:
        model.load_state_dict(new_state_dict)
        print("Model loaded successfully")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise

    # Set to evaluation mode and move to specified device
    model.eval()
    return model.to(device)
def _tile_offsets(length, patch_size, stride):
    """Return patch start offsets that together cover an axis of ``length``.

    Offsets advance by ``stride``; a final offset is anchored at
    ``length - patch_size`` (0 when the axis is shorter than one patch) so
    the trailing border is always processed. Full coverage requires
    ``stride <= patch_size``.
    """
    last = max(length - patch_size, 0)
    offsets = list(range(0, last, stride))
    offsets.append(last)
    return offsets


def snowmaster_predict(model, input_img, output_dir, device='cuda'):
    """
    Use SnowMaster model to perform snow removal on a single image

    Args:
        model (torch.nn.Module): Loaded model
        input_img (str or PIL.Image.Image): Input image path or PIL image
        output_dir (str): Output directory (created if missing)
        device (str): Device for inference, default is 'cuda'
    Returns:
        str: Path to the saved output image, '<name>_desnowed.png' inside
            output_dir; <name> is the input file stem, or 'predicted_image'
            when a PIL image is passed
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # If input is an image path, open the image
    if isinstance(input_img, str):
        # Close the file handle once the pixel data has been converted
        with Image.open(input_img) as raw:
            img = raw.convert('RGB')
        img_name = os.path.splitext(os.path.basename(input_img))[0]
    else:
        # Normalise to 3 channels as well (e.g. RGBA or greyscale inputs)
        img = input_img.convert('RGB')
        img_name = "predicted_image"

    # Resize image if too large
    max_size = 1500
    if max(img.size) > max_size:
        ratio = max_size / max(img.size)
        # Keep each side at 1 pixel or more so extreme aspect ratios stay valid
        new_size = tuple(max(1, int(dim * ratio)) for dim in img.size)
        img = img.resize(new_size, Image.LANCZOS)

    # Convert to tensor
    transform = tfs.Compose([tfs.ToTensor()])
    input_tensor = transform(img).unsqueeze(0)

    # Inference
    with torch.no_grad():
        _, _, h, w = input_tensor.shape
        # Check image size, use patch-based processing if too large.
        # With the 1500-pixel cap above this threshold is not reached; the
        # branch is kept as a safeguard should the cap be raised.
        if h * w > 2**24:
            patch_size = 1024
            stride = 1000
            output_tensor = torch.zeros_like(input_tensor)
            for i in _tile_offsets(h, patch_size, stride):
                for j in _tile_offsets(w, patch_size, stride):
                    patch = input_tensor[:, :, i:i + patch_size, j:j + patch_size]
                    patch_output = model(patch.to(device))
                    output_tensor[:, :, i:i + patch_size, j:j + patch_size] = patch_output.cpu()
        else:
            input_tensor = input_tensor.to(device)
            output_tensor = model(input_tensor)

        # Clamp output values to valid range
        output_tensor = output_tensor.clamp_(0, 1)

    # Convert to PIL image; drop only the batch dimension so that a
    # 1-pixel-high or 1-pixel-wide image keeps its spatial dimensions
    output_image = tfs.ToPILImage()(output_tensor.squeeze(0).cpu())

    # Save image
    save_path = os.path.join(output_dir, f"{img_name}_desnowed.png")
    output_image.save(save_path)
    print(f"Processed: {save_path}")
    return save_path
def inference(model_path, input_path, output_path, model_name, device='cuda'):
    """
    Batch inference function:
    Args:
        model_path: Path to model checkpoint (string)
        input_path: Path to input image folder (string)
        output_path: Path to output image save folder (string)
        model_name: Model name
        device: Device for inference
    Returns:
        Path of the last image saved, or None when the input folder
        contains no images
    Raises:
        FileNotFoundError: If model_path or input_path does not exist
        SystemExit: With status 1 if the model cannot be loaded
    """
    # Check if model file and input directory exist
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file does not exist: {model_path}")
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input path does not exist: {input_path}")

    # Create output directory
    os.makedirs(output_path, exist_ok=True)

    print(f"\nUsing model: {model_path}")
    print(f"Input directory: {input_path}")
    print(f"Output directory: {output_path}")

    # Load current model
    try:
        model = load_snowmaster_model(model_path, device)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        sys.exit(1)

    dataset = SnowDataset(
        input_path=input_path,
        desnow_input_size=None,
        model_name=model_name
    )
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0
    )

    def tile_offsets(length, patch_size, stride):
        # Start offsets covering the whole axis: step by `stride`, then
        # anchor a final patch at the end so the border is not skipped.
        last = max(length - patch_size, 0)
        return list(range(0, last, stride)) + [last]

    # Stays None when the dataset is empty, so the final return is always bound
    save_path = None
    with torch.no_grad():
        for snow_input, _, _, image_id in dataloader:
            _, _, h, w = snow_input.shape
            # Check image size, use patch-based processing if too large
            if h * w > 2**24:
                patch_size = 1024  # Can be adjusted as needed
                stride = 1000  # Can be adjusted as needed
                output_tensor = torch.zeros_like(snow_input)
                for i in tile_offsets(h, patch_size, stride):
                    for j in tile_offsets(w, patch_size, stride):
                        patch = snow_input[:, :, i:i + patch_size, j:j + patch_size]
                        patch_output = model(patch.to(device))
                        output_tensor[:, :, i:i + patch_size, j:j + patch_size] = patch_output.cpu()
            else:
                snow_input = snow_input.to(device)
                output_tensor = model(snow_input)

            output_tensor = output_tensor.clamp_(0, 1)
            # Drop only the batch dimension (batch_size is 1) so that
            # size-1 spatial dimensions are preserved
            output_image = tfs.ToPILImage()(output_tensor.squeeze(0).cpu())
            # image_id is a length-1 batch of names produced by the DataLoader
            save_path = os.path.join(output_path, f"{image_id[0]}_desnowed.png")
            output_image.save(save_path)
    return save_path
if __name__ == "__main__":
    # Pass the weights directory: load_snowmaster_model() looks for
    # 'checkpoint_0318.pth' inside the directory it is given.
    ckpt_dir = "weights"
    input_img = "assests/1.jpg"
    output_path = "output"
    # Fall back to CPU when no CUDA device is available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    model = load_snowmaster_model(ckpt_dir, device)

    # Single image inference
    snowmaster_predict(model, input_img, output_path, device)