File size: 6,253 Bytes
691ba3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import hashlib
import os
import pickle


def get_image_coordinates(H, W):
    """Build the flattened grid of normalized pixel coordinates for an image.

    Both axes are sampled uniformly on [-1, 1]. Rows of the result are in
    row-major order (every column of the first image row, then the second
    row, and so on) and each row holds an ``(x, y)`` pair.

    Args:
        H: Image height (number of samples along y)
        W: Image width (number of samples along x)

    Returns:
        coords: Tensor of shape (H*W, 2) with normalized coordinates in [-1, 1]
    """
    row_axis = torch.linspace(-1, 1, H)
    col_axis = torch.linspace(-1, 1, W)

    # 'ij' indexing keeps both grids shaped (H, W): rows follow y, columns follow x
    grid_y, grid_x = torch.meshgrid(row_axis, col_axis, indexing='ij')

    # Flatten each grid row-major, then pair the values up as (x, y)
    return torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)


def image_to_tensor(image):
    """Flatten a PIL image into a per-pixel RGB tensor.

    Args:
        image: PIL Image (any mode; non-RGB images are converted to RGB first)

    Returns:
        Tensor of shape (H*W, 3) with values in [0, 1], pixels in row-major order
    """
    # Force three channels so the final reshape to (-1, 3) is always valid
    rgb = image if image.mode == 'RGB' else image.convert('RGB')

    # ToTensor yields a channels-first (C, H, W) tensor scaled to [0, 1]
    to_tensor = transforms.ToTensor()
    channels_first = to_tensor(rgb)

    # Move channels last, (H, W, C), then collapse the spatial axes to (H*W, 3)
    return channels_first.permute(1, 2, 0).reshape(-1, 3)


def tensor_to_image(tensor, H, W):
    """Convert a flat per-pixel tensor back to a PIL Image.

    Values are clamped to [0, 1], scaled to [0, 255] and rounded to the
    nearest integer before being cast to 8-bit.

    Args:
        tensor: Tensor of shape (H*W, 3) with values in [0, 1]. It may live on
            any device and may still be attached to the autograd graph.
        H: Image height
        W: Image width

    Returns:
        PIL Image
    """
    # detach() first: calling .numpy() on a tensor that requires grad raises,
    # which happens when raw model output is passed in outside torch.no_grad()
    img = tensor.detach().reshape(H, W, 3)

    # Clamp to [0, 1] so out-of-range predictions do not wrap around in uint8
    img = torch.clamp(img, 0, 1)

    # Scale to [0, 255] and round to nearest. A bare astype(uint8) truncates,
    # so float error such as 253.99998 would silently become 253.
    img = np.rint(img.cpu().numpy() * 255).astype(np.uint8)

    # Convert to PIL Image
    return Image.fromarray(img)


def downsample_image(image, scale_factor):
    """Downsample image by scale_factor using bicubic resampling.

    Each dimension is floor-divided by ``scale_factor``. The result is cast to
    ``int`` so non-integer factors (e.g. ``2.0`` or ``1.5``) work, and a
    dimension that would floor to zero is kept at one pixel so the resize
    call always receives a valid size.

    Args:
        image: PIL Image
        scale_factor: Downsampling factor (e.g., 2 for half size)

    Returns:
        Downsampled PIL Image
    """
    W, H = image.size

    # `or 1` only replaces an exact 0 (scale larger than the image); negative
    # sizes from a negative scale still reach resize() and are rejected there.
    new_W = int(W // scale_factor) or 1
    new_H = int(H // scale_factor) or 1

    return image.resize((new_W, new_H), Image.BICUBIC)


def train_siren(model, coords, pixels, num_steps=2000, learning_rate=1e-4, device='cpu'):
    """Fit a SIREN model to an image with full-batch Adam on an MSE loss.

    Every step feeds all coordinates through the model at once and compares
    the prediction with all pixels. Progress is printed every 200 steps.

    Args:
        model: SIREN model
        coords: Coordinate tensor (H*W, 2)
        pixels: Pixel values tensor (H*W, 3)
        num_steps: Number of training steps
        learning_rate: Learning rate
        device: Device to train on

    Returns:
        Trained model and the list of per-step loss values
    """
    # Put the network and both tensors on the training device
    model = model.to(device)
    inputs, targets = coords.to(device), pixels.to(device)

    criterion = torch.nn.MSELoss()
    adam = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = []
    for step_no in range(1, num_steps + 1):
        # Forward pass and loss over the whole image
        loss = criterion(model(inputs), targets)

        # Gradient step
        adam.zero_grad()
        loss.backward()
        adam.step()

        # The recorded value is the loss measured before this step's update
        loss_value = loss.item()
        history.append(loss_value)

        # Periodic progress report
        if step_no % 200 == 0:
            print(f"Step {step_no}/{num_steps}, Loss: {loss_value:.6f}")

    return model, history


def compute_psnr(img1, img2):
    """Peak Signal-to-Noise Ratio between two images, assuming a peak value of 1.0.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1]
        img2: Second image tensor (H*W, 3) in [0, 1]

    Returns:
        PSNR value in dB as a Python float; ``inf`` when the images are identical
    """
    squared_error = torch.nn.MSELoss()(img1, img2)

    # Identical images: avoid dividing by zero and report infinite PSNR
    if squared_error.item() == 0:
        return float('inf')

    # 20 * log10(peak / RMSE) with peak = 1.0
    rmse = squared_error.sqrt()
    return (20 * torch.log10(1.0 / rmse)).item()


def compute_mae(img1, img2):
    """Mean Absolute Error between two images, averaged over every element.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1]
        img2: Second image tensor (H*W, 3) in [0, 1]

    Returns:
        MAE value as a Python float
    """
    abs_error = torch.nn.L1Loss(reduction='mean')
    return abs_error(img1, img2).item()


def compute_ssim_simple(img1, img2, window_size=11):
    """Compute a single global SSIM score between two images.

    This is a simplified SSIM: the means, variances and covariance are taken
    once over all elements of each tensor (every pixel and channel together)
    rather than over local sliding windows.

    Args:
        img1: First image tensor (H*W, 3) in [0, 1]
        img2: Second image tensor (H*W, 3) in [0, 1]
        window_size: Accepted for interface compatibility; not used by this
            global formulation

    Returns:
        SSIM value as a Python float. Identical images score about 1; the
        value can drop below 0 when the covariance term is negative.
    """
    # Stabilising constants for a dynamic range of 1.0
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    # Global means and deviations from them
    mean_a, mean_b = img1.mean(), img2.mean()
    dev_a = img1 - mean_a
    dev_b = img2 - mean_b

    # Population (biased) variances and covariance
    var_a = dev_a.pow(2).mean()
    var_b = dev_b.pow(2).mean()
    cov_ab = (dev_a * dev_b).mean()

    numerator = (2 * mean_a * mean_b + c1) * (2 * cov_ab + c2)
    denominator = (mean_a.pow(2) + mean_b.pow(2) + c1) * (var_a + var_b + c2)

    return (numerator / denominator).item()


def get_model_cache_path(image_path, scale_factor, training_steps, hidden_features, hidden_layers):
    """Generate cache path for trained model.

    The cache directory ``model_cache`` is created (relative to the current
    working directory) if it does not exist yet.

    The image name is the final path component with only its last extension
    removed, so ``cat.v1.png`` and ``cat.v2.png`` map to different cache
    files. Images that share a file name but live in different directories
    still map to the same cache file.

    Args:
        image_path: Path to image
        scale_factor: Upscaling factor
        training_steps: Number of training steps
        hidden_features: Network width
        hidden_layers: Network depth

    Returns:
        Cache file path
    """
    cache_dir = "model_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Strip the directory and only the LAST extension. Cutting at the first
    # dot made names such as "cat.v1.png" and "cat.v2.png" collide.
    image_name = os.path.splitext(os.path.basename(image_path))[0]

    # Create descriptive filename
    filename = f"{training_steps}steps_{scale_factor}x_{image_name}_h{hidden_features}_l{hidden_layers}.pkl"

    return os.path.join(cache_dir, filename)


def save_model(model, cache_path):
    """Save model to cache.

    The model's state dict is pickled to ``cache_path``. Tensors are copied to
    the CPU first so the cache can be read on a machine without the device the
    model was trained on. The data is written to a temporary sibling file and
    then moved into place, so an interrupted save never leaves a truncated
    file at ``cache_path``.

    Args:
        model: SIREN model
        cache_path: Path to save model
    """
    state = model.state_dict()

    # Shallow copy, then swap each tensor for its CPU version; the model's own
    # parameters are left on their original device.
    cpu_state = state.copy()
    for key, value in state.items():
        if torch.is_tensor(value):
            cpu_state[key] = value.cpu()

    # Keep the versioning metadata that state_dict() attaches, when present
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        cpu_state._metadata = metadata

    tmp_path = f"{cache_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(cpu_state, f)
        # Swap the finished file into place in a single rename
        os.replace(tmp_path, cache_path)
    finally:
        # Only still present if the dump or the rename failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Model saved to cache: {cache_path}")


def load_model(model, cache_path):
    """Load model from cache.

    Args:
        model: SIREN model (architecture must match)
        cache_path: Path to cached model

    Returns:
        Loaded model, or None if the cache file doesn't exist or cannot be
        unpickled (e.g. it is empty or truncated)

    Raises:
        Whatever ``model.load_state_dict`` raises when the cached weights do
        not fit the given architecture.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this at cache files written by this program.
    try:
        # Open directly instead of checking existence first, so a file removed
        # between the check and the open cannot cause a crash
        with open(cache_path, 'rb') as f:
            state_dict = pickle.load(f)
    except FileNotFoundError:
        return None
    except (pickle.UnpicklingError, EOFError) as err:
        # A damaged cache is treated as a miss so the caller can retrain
        print(f"Ignoring unreadable model cache {cache_path}: {err}")
        return None

    model.load_state_dict(state_dict)
    print(f"Model loaded from cache: {cache_path}")
    return model