"""
SparseC-AFM: AFM Super-Resolution Demo Application

A simple Gradio-based web app for experimenting with Swin Transformer
models for AFM (Atomic Force Microscopy) map super-resolution.

Usage:
    python app.py

Then open http://127.0.0.1:7860 in your browser.
"""

import io
import tempfile
from pathlib import Path
from typing import Tuple, Optional

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from scipy import stats as scipy_stats

# Put this file's directory on sys.path so the `src` package imports resolve
import sys
sys.path.insert(0, str(Path(__file__).parent))

from src.models.our_method.swin_cafm import SwinCAFM


# ─────────────────────────────────────────────────────────────────────────────
# Configuration
# ─────────────────────────────────────────────────────────────────────────────

MODEL_CONFIGS = {
    "2x": {"input_size": 64, "upscale": 2, "weights": "data/weights/2x/2x.pth"},
    "4x": {"input_size": 64, "upscale": 4, "weights": "data/weights/4x/4x.pth"},
    "8x": {"input_size": 32, "upscale": 8, "weights": "data/weights/8x/8x.pth"},
}
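# Note: input_size is the tile edge used at inference (uploads are split into
# input_size x input_size tiles before upsampling); weights paths are resolved
# relative to this file's directory in get_model().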

# Demo samples (center-cropped for fast processing)
DEMO_SAMPLES = {
    "MoS2 on SiO2 - Topography": "demo/MoS2_SiO2_Topography.npy",
    "MoS2 on SiO2 - Current": "demo/MoS2_SiO2_Current.npy",
    "MoS2 on Sapphire - Topography": "demo/MoS2_Sapphire_Topography.npy",
    "MoS2 on Sapphire - Current": "demo/MoS2_Sapphire_Current.npy",
}

COLORMAPS = ["viridis", "plasma", "inferno", "magma", "cividis", "hot", "coolwarm", "gray"]

SUPPORTED_FORMATS = {
    ".npy": "NumPy array",
    ".tif": "TIFF image",
    ".tiff": "TIFF image",
    ".png": "PNG image",
    ".jpg": "JPEG image",
    ".jpeg": "JPEG image",
    ".bmp": "BMP image",
    ".webp": "WebP image",
}


# ─────────────────────────────────────────────────────────────────────────────
# Device Detection
# ─────────────────────────────────────────────────────────────────────────────

def get_available_devices() -> list[str]:
    """Detect available compute devices."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        devices.append(f"cuda ({gpu_name})")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        devices.append("mps (Apple Silicon)")
    return devices


def parse_device(device_str: str) -> str:
    """Extract device name from display string."""
    if device_str.startswith("cuda"):
        return "cuda"
    elif device_str.startswith("mps"):
        return "mps"
    return "cpu"


# ─────────────────────────────────────────────────────────────────────────────
# Image I/O
# ─────────────────────────────────────────────────────────────────────────────

def load_map(filepath: str) -> np.ndarray:
    """
    Load a conductivity or topography map from various formats.
    Returns a 2D numpy array (grayscale).
    """
    ext = Path(filepath).suffix.lower()

    if ext == ".npy":
        data = np.load(filepath)
        # Handle 3D arrays: a small last dimension (<= 4) is treated as
        # channels-last and the first channel is kept; otherwise assume a
        # stack of maps and take the first slice.
        if data.ndim == 3:
            data = data[:, :, 0] if data.shape[2] <= 4 else data[0]
        return data.astype(np.float32)

    elif ext in [".tif", ".tiff"]:
        try:
            import tifffile
            data = tifffile.imread(filepath)
        except ImportError:
            # Fallback to PIL
            img = Image.open(filepath)
            data = np.array(img)
        if data.ndim == 3:
            data = data[:, :, 0]
        return data.astype(np.float32)

    elif ext in [".png", ".jpg", ".jpeg", ".bmp", ".webp"]:
        img = Image.open(filepath).convert("L")  # Convert to grayscale
        return np.array(img, dtype=np.float32)

    else:
        raise ValueError(f"Unsupported format: {ext}. Supported: {list(SUPPORTED_FORMATS.keys())}")


def apply_colormap(data: np.ndarray, cmap_name: str = "viridis") -> np.ndarray:
    """Apply a matplotlib colormap to grayscale data, returning RGB uint8."""
    # Normalize to [0, 1]
    normalized = (data - data.min()) / (data.max() - data.min() + 1e-8)

    # Apply colormap
    cmap = plt.get_cmap(cmap_name)
    colored = cmap(normalized)[:, :, :3]  # Drop alpha channel

    return (colored * 255).astype(np.uint8)


def save_to_format(data: np.ndarray, format: str, cmap_name: str = "viridis") -> str:
    """Save array to a temporary file in the specified format."""
    temp_dir = tempfile.gettempdir()

    if format == "npy":
        filepath = Path(temp_dir) / "upsampled_result.npy"
        np.save(filepath, data)

    elif format == "tiff":
        filepath = Path(temp_dir) / "upsampled_result.tiff"
        try:
            import tifffile
            tifffile.imwrite(filepath, data.astype(np.float32))
        except ImportError:
            # Fallback: save as 16-bit normalized
            normalized = (data - data.min()) / (data.max() - data.min() + 1e-8)
            img = Image.fromarray((normalized * 65535).astype(np.uint16))
            img.save(filepath)

    elif format == "png":
        filepath = Path(temp_dir) / "upsampled_result.png"
        colored = apply_colormap(data, cmap_name)
        Image.fromarray(colored).save(filepath)

    elif format == "csv":
        filepath = Path(temp_dir) / "upsampled_result.csv"
        np.savetxt(filepath, data, delimiter=",")

    else:
        raise ValueError(f"Unsupported export format: {format}")

    return str(filepath)


# ─────────────────────────────────────────────────────────────────────────────
# Model Management
# ─────────────────────────────────────────────────────────────────────────────

# Global model cache: {(scale, device): model}
_MODEL_CACHE: dict[Tuple[str, str], torch.nn.Module] = {}


def create_model(scale: str) -> torch.nn.Module:
    """Create model architecture for the given scale."""
    config = MODEL_CONFIGS[scale]
    upscale = config["upscale"]
    img_size = config["input_size"]

    return SwinCAFM(
        upscale=upscale,
        img_size=img_size,
        window_size=8,
        img_range=1.0,
        depths=[6, 6, 6, 6, 6, 6],
        embed_dim=180,
        num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2,
        drop_path_rate=0.1,
        norm_layer=torch.nn.LayerNorm,
        upsampler="pixelshuffle",
        resi_connection="1conv",
    )


def get_model(scale: str, device: str) -> torch.nn.Module:
    """Load and cache a model for the given scale and device."""
    key = (scale, device)

    if key not in _MODEL_CACHE:
        config = MODEL_CONFIGS[scale]
        weights_path = Path(__file__).parent / config["weights"]

        if not weights_path.exists():
            raise FileNotFoundError(f"Weights not found: {weights_path}")

        # Load weights file
        loaded = torch.load(weights_path, map_location=device, weights_only=False)

        # Handle different save formats:
        # 1. Full model object (SwinCAFM) - use directly
        # 2. State dict (OrderedDict) - load into new model
        # 3. Dict with "params" key - extract and load
        if isinstance(loaded, SwinCAFM):
            model = loaded
        else:
            model = create_model(scale)
            state_dict = loaded
            if isinstance(state_dict, dict) and "params" in state_dict:
                state_dict = state_dict["params"]
            model.load_state_dict(state_dict, strict=False)

        model = model.to(device).eval()
        _MODEL_CACHE[key] = model

    return _MODEL_CACHE[key]
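

# Illustrative usage (assumes the 4x weights file is present on disk):
#
#     model = get_model("4x", "cpu")    # first call builds the model and loads weights
#     get_model("4x", "cpu") is model   # True: second call is served from the cache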


def clear_model_cache():
    """Clear the model cache to free memory."""
    _MODEL_CACHE.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# ─────────────────────────────────────────────────────────────────────────────
# Image Processing (Tiled)
# ─────────────────────────────────────────────────────────────────────────────

def pad_to_multiple(data: np.ndarray, tile_size: int) -> Tuple[np.ndarray, Tuple[int, int], str]:
    """
    Pad image so dimensions are multiples of tile_size.
    Returns (padded_data, original_shape, warning_message).
    """
    h, w = data.shape[:2]
    original_shape = (h, w)
    warnings = []

    # Calculate padding needed
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
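    # Example: h=100, tile_size=64 -> 100 % 64 = 36, pad_h = (64 - 36) % 64 = 28,
    # giving a padded height of 128; the outer "% tile_size" makes the padding 0
    # when the dimension is already an exact multiple.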

    if pad_h > 0 or pad_w > 0:
        warnings.append(f"Input ({h}x{w}) padded to ({h + pad_h}x{w + pad_w}) for tiling.")
        data = np.pad(
            data,
            ((0, pad_h), (0, pad_w)),
            mode='reflect'  # Use reflect padding to avoid edge artifacts
        )

    warning = " ".join(warnings)
    return data, original_shape, warning


def process_tiled(
    data: np.ndarray,
    model: torch.nn.Module,
    tile_size: int,
    upscale: int,
    device: str,
) -> np.ndarray:
    """
    Process a large image by splitting into tiles, upsampling each, and stitching.

    Args:
        data: Input image (H, W), normalized to [0, 1]
        model: The upsampling model
        tile_size: Size of each tile (e.g., 64 for 2x/4x models)
        upscale: Upscaling factor (2, 4, or 8)
        device: Compute device

    Returns:
        Upsampled image (H*upscale, W*upscale)
    """
    h, w = data.shape
    out_h, out_w = h * upscale, w * upscale

    # Initialize output array
    output = np.zeros((out_h, out_w), dtype=np.float32)

    # Process each tile
    n_tiles_h = h // tile_size
    n_tiles_w = w // tile_size

    for i in range(n_tiles_h):
        for j in range(n_tiles_w):
            # Extract tile
            y_start = i * tile_size
            x_start = j * tile_size
            tile = data[y_start:y_start + tile_size, x_start:x_start + tile_size]

            # Run inference on tile
            X = torch.tensor(tile, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                tile_out = model(X).cpu().numpy()[0]

            # Place in output
            out_y = i * tile_size * upscale
            out_x = j * tile_size * upscale
            output[out_y:out_y + tile_size * upscale, out_x:out_x + tile_size * upscale] = tile_out

    return output
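

# Illustrative sketch (not part of the app's pipeline): a shape sanity check for
# pad_to_multiple + process_tiled using a stand-in model. _NearestUpsampler and
# _tiling_shape_check are hypothetical helpers; the stand-in only mimics the
# (1, H, W) -> (1, H*upscale, W*upscale) calling convention that process_tiled
# relies on above.
class _NearestUpsampler(torch.nn.Module):
    def __init__(self, upscale: int):
        super().__init__()
        self.upscale = upscale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Repeat every pixel `upscale` times along both spatial axes.
        return x.repeat_interleave(self.upscale, dim=-2).repeat_interleave(self.upscale, dim=-1)


def _tiling_shape_check() -> None:
    """Check that a 100x150 input at 4x comes back as 400x600 after cropping."""
    data = np.random.default_rng(0).random((100, 150)).astype(np.float32)
    padded, (h, w), _ = pad_to_multiple(data, tile_size=64)          # -> 128x192
    out = process_tiled(padded, _NearestUpsampler(4), 64, 4, "cpu")  # 2x3 tiles
    out = out[: h * 4, : w * 4]                                      # crop padding off
    assert out.shape == (400, 600)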


def center_crop(data: np.ndarray, target_size: int) -> np.ndarray:
    """
    Center crop the input to target_size x target_size.
    Used for demo samples for fast processing.
    """
    h, w = data.shape[:2]
    start_h = (h - target_size) // 2
    start_w = (w - target_size) // 2
    return data[start_h:start_h + target_size, start_w:start_w + target_size]


# ─────────────────────────────────────────────────────────────────────────────
# Statistics (Gwyddion-inspired)
# ─────────────────────────────────────────────────────────────────────────────

def compute_statistics(arr: np.ndarray) -> dict:
    """Compute Gwyddion-inspired surface statistics."""
    flat = arr.flatten()
    centered = arr - np.mean(arr)
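    # `centered` feeds the roughness metrics below: Rq = sqrt(mean(centered**2))
    # and Ra = mean(|centered|), the RMS and arithmetic-average deviations from
    # the mean plane.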

    return {
        "Dimensions": f"{arr.shape[0]} x {arr.shape[1]} px",
        "Min": f"{arr.min():.6g}",
        "Max": f"{arr.max():.6g}",
        "Mean": f"{arr.mean():.6g}",
        "Median": f"{np.median(arr):.6g}",
        "Std Dev (Οƒ)": f"{arr.std():.6g}",
        "RMS Roughness (Rq)": f"{np.sqrt(np.mean(centered**2)):.6g}",
        "Avg Roughness (Ra)": f"{np.mean(np.abs(centered)):.6g}",
        "Peak-to-Valley (Rz)": f"{arr.max() - arr.min():.6g}",
        "Skewness": f"{scipy_stats.skew(flat):.4f}",
        "Kurtosis": f"{scipy_stats.kurtosis(flat):.4f}",
    }


# ─────────────────────────────────────────────────────────────────────────────
# Main Inference Pipeline
# ─────────────────────────────────────────────────────────────────────────────

def run_inference(
    file,
    demo_sample: str,
    scale: str,
    device_str: str,
    colormap: str,
    export_format: str,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], dict, Optional[str], str]:
    """
    Main inference function.

    Returns:
        - input_image for display
        - output_image for display
        - statistics dictionary
        - path to downloadable file
        - status/warning message
    """
    # Determine input source: demo sample or uploaded file
    use_demo = demo_sample and demo_sample != "Upload your own"

    if not use_demo and file is None:
        return None, None, {}, None, "Please select a demo sample or upload an image file."

    try:
        # Load input
        if use_demo:
            demo_path = Path(__file__).parent / DEMO_SAMPLES[demo_sample]
            data = np.load(demo_path)
        else:
            # gr.File may yield a path string or a tempfile-like object with a
            # .name attribute, depending on the Gradio version.
            data = load_map(file if isinstance(file, str) else file.name)

        original_shape = data.shape
        original_min, original_max = data.min(), data.max()

        # Get model config
        config = MODEL_CONFIGS[scale]
        tile_size = config["input_size"]
        upscale_factor = config["upscale"]

        # Normalize to [0, 1]
        normalized = (data - original_min) / (original_max - original_min + 1e-8)

        # Load model
        device = parse_device(device_str)
        model = get_model(scale, device)

        if use_demo:
            # Demo samples: use center crop for fast processing
            cropped = center_crop(normalized, tile_size)

            # Single tile inference
            X = torch.tensor(cropped, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                output = model(X).cpu().numpy()[0]

            input_vis = cropped
            out_h, out_w = output.shape
            status = f"Demo: Center-cropped {original_shape[0]}x{original_shape[1]} to {tile_size}x{tile_size} -> {out_h}x{out_w} using {scale} model on {device}."

        else:
            # User uploads: full tiled processing (preserves all pixels)
            padded, orig_shape, pad_warning = pad_to_multiple(normalized, tile_size)
            padded_h, padded_w = padded.shape

            # Process using tiled approach
            output = process_tiled(padded, model, tile_size, upscale_factor, device)

            # Crop output back to original size (scaled)
            out_h = orig_shape[0] * upscale_factor
            out_w = orig_shape[1] * upscale_factor
            output = output[:out_h, :out_w]

            # Input visualization matches original
            input_vis = normalized[:orig_shape[0], :orig_shape[1]]

            # Build status message
            n_tiles = (padded_h // tile_size) * (padded_w // tile_size)
            status = f"Processed {original_shape[0]}x{original_shape[1]} in {n_tiles} tiles -> {out_h}x{out_w} using {scale} model on {device}."
            if pad_warning:
                status = f"Note: {pad_warning}\n{status}"

        # Denormalize output to original scale
        output_denorm = output * (original_max - original_min) + original_min
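        # This inverts the min-max normalization applied above, recovering the
        # original data range (up to the 1e-8 epsilon in the divisor).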

        # Apply colormap for visualization
        input_colored = apply_colormap(input_vis, colormap)
        output_colored = apply_colormap(output, colormap)

        # Compute statistics on denormalized output
        stats = compute_statistics(output_denorm)

        # Save to requested format
        download_path = save_to_format(output_denorm, export_format, colormap)

        return input_colored, output_colored, stats, download_path, status

    except Exception as e:
        return None, None, {}, None, f"Error: {e}"


# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
# ─────────────────────────────────────────────────────────────────────────────

def create_app() -> gr.Blocks:
    """Create and configure the Gradio application."""

    with gr.Blocks(title="SparseC-AFM: AFM Super-Resolution") as app:

        gr.Markdown("""
        # SparseC-AFM: AFM Super-Resolution

        **Supported formats:** .npy, .tif/.tiff, .png, .jpg/.jpeg, .bmp, .webp
        """)

        with gr.Row():
            # Left column: inputs
            with gr.Column(scale=1):
                # Demo sample selector
                demo_dropdown = gr.Dropdown(
                    choices=["Upload your own"] + list(DEMO_SAMPLES.keys()),
                    value="Upload your own",
                    label="Select",
                )

                file_input = gr.File(
                    label="Or Upload Your Own (full resolution)",
                    file_types=[".npy", ".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp", ".webp"],
                )

                with gr.Row():
                    scale_dropdown = gr.Dropdown(
                        choices=list(MODEL_CONFIGS.keys()),
                        value="4x",
                        label="Upscale Factor",
                    )
                    device_dropdown = gr.Dropdown(
                        choices=get_available_devices(),
                        value=get_available_devices()[0],
                        label="Compute Device",
                    )

                with gr.Row():
                    colormap_dropdown = gr.Dropdown(
                        choices=COLORMAPS,
                        value="viridis",
                        label="Colormap",
                    )
                    export_dropdown = gr.Dropdown(
                        choices=["npy", "tiff", "png", "csv"],
                        value="npy",
                        label="Download Format",
                    )

                run_button = gr.Button("Upsample", variant="primary", size="lg")

                status_box = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=2,
                )

            # Right column: outputs
            with gr.Column(scale=2):
                # Image comparison - side by side
                with gr.Row():
                    input_image = gr.Image(
                        label="Original",
                        type="numpy",
                    )
                    output_image = gr.Image(
                        label="Upsampled",
                        type="numpy",
                    )

                with gr.Row():
                    # Statistics panel
                    stats_output = gr.JSON(
                        label="Sample Statistics",
                    )

                    # Download
                    download_output = gr.File(
                        label="Download Result",
                    )

        # Connect the interface
        run_button.click(
            fn=run_inference,
            inputs=[file_input, demo_dropdown, scale_dropdown, device_dropdown, colormap_dropdown, export_dropdown],
            outputs=[input_image, output_image, stats_output, download_output, status_box],
        )

    return app


# ─────────────────────────────────────────────────────────────────────────────
# Entry Point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    app = create_app()
    app.launch(server_name="0.0.0.0", server_port=7860)