File size: 5,496 Bytes
c71037b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a176fb5
c71037b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import io
import time
import torch
import random
import requests
import numpy as np
import geopandas as gpd
import h5py
import xarray as xr
from torch import nn
from torch.utils.data import Dataset
import albumentations as A
from torchvision.transforms import v2
import fsspec
from PIL import Image

from utils import (
    shared_store, process_pool, write_last_updated,
    AddPoissonNoise, AddSaltPepperNoise
)


class DinoDataset(Dataset):
    """
    DinoDataset - resolution-agnostic loader for Core-Five.

    Streams random square crops of HR satellite rasters from the
    ``gajeshladhar/core-five`` dataset on Hugging Face and turns every crop
    into a stack of randomly augmented views (torchvision ``v2`` transforms).
    Worker jobs submitted to ``process_pool`` keep pushing finished batches
    into ``shared_store`` (both imported from ``utils``).

    ---
    Author: Gajesh Ladhar
    LinkedIn: https://www.linkedin.com/in/gajeshladhar/
    Hugging Face: https://huggingface.co/gajeshladhar
    """

    # Root under which the relative ``path`` column of the metadata resolves.
    _BASE_URL = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/"
    # (connect, read) timeouts in seconds. The read timeout bounds the gap
    # between received chunks, not the whole download, so large files still work.
    _REQUEST_TIMEOUT = (10, 60)

    def __init__(self, imgsz, batch_size=1, queue_size=50, num_workers=6):
        """
        Load the remote Core-Five metadata and start the background fetchers.

        Args:
            imgsz (int): Patch size in pixels (must be at least 320). The
                extracted patch has side ``2 * (imgsz // 2)``.
            batch_size (int): Number of patches per stored batch.
            queue_size (int): Number of batches kept in the store before
                random slots start being overwritten.
            num_workers (int): Number of ``fetch_and_store`` jobs submitted to
                ``process_pool``.

        Raises:
            ValueError: If ``imgsz`` is smaller than 320.
        """
        if imgsz < 320:
            raise ValueError(f"imgsz must be ≥ 320 for stable patch extraction — got {imgsz}")
        self.imgsz = imgsz
        metadata_url = self._BASE_URL + "metadata.parquet"
        # Context manager so the remote file handle is closed after parsing.
        with fsspec.open(metadata_url) as handle:
            self.df_metadata = gpd.read_parquet(handle)
        self.batch_size = batch_size
        self.queue_size = queue_size
        self.store = shared_store

        for _ in range(num_workers):
            process_pool.submit(self.fetch_and_store)

    @staticmethod
    def transform(batch, num_views=8):
        """
        Build ``num_views`` randomly augmented views of every image in ``batch``.

        Each view is a random resized crop to 256x256 followed by colour
        jitter, optional grayscale / blur / solarize, a random horizontal
        flip, conversion to float32 in [0, 1] and ImageNet mean/std
        normalisation.

        Args:
            batch: Iterable of channel-first ``uint8`` arrays of shape
                ``(3, H, W)`` (three channels are required by the 3-value
                normalisation statistics).
            num_views (int): Number of augmented views per image.

        Returns:
            dict: ``{"views": tensor}`` where ``tensor`` is float32 with shape
            ``(len(batch), num_views, 3, 256, 256)``.
        """
        augment_satellite = v2.Compose(
            [
                v2.RandomResizedCrop(256, scale=(0.08, 1.0)),
                v2.RandomApply([v2.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
                v2.RandomGrayscale(p=0.2),
                v2.RandomApply([v2.GaussianBlur(kernel_size=7, sigma=(0.1, 2.0))]),
                v2.RandomApply([v2.RandomSolarize(threshold=128)], p=0.2),
                v2.RandomHorizontalFlip(),
                v2.ToImage(),
                v2.ToDtype(torch.float32, scale=True),
                v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        per_image = []
        for img in batch:
            # PIL wants channel-last data; convert once and reuse for all views.
            pil_img = Image.fromarray(img.transpose(1, 2, 0))
            views = [
                # The pipeline already yields a tensor; unwrap it to a plain
                # torch.Tensor instead of copy-constructing with torch.tensor().
                augment_satellite(pil_img).as_subclass(torch.Tensor)
                for _ in range(num_views)
            ]
            per_image.append(torch.stack(views))

        return {
            "views": torch.stack(per_image),
        }

    @staticmethod
    def _random_start(length, size):
        """
        Pick a uniformly random start index for a window of ``size`` elements
        on an axis of ``length`` elements, so that the window fits entirely.

        Raises:
            ValueError: If the window is larger than the axis.
        """
        if size > length:
            raise ValueError(
                f"cannot take a window of {size} from an axis of length {length}"
            )
        # randint's upper bound is exclusive, hence the +1 so the last valid
        # start (length - size) can be drawn too.
        return int(np.random.randint(0, length - size + 1))

    def _fetch_random_patch(self):
        """
        Download one randomly chosen Core-Five file and return a random
        square crop of its ``/hr/data`` raster as a ``uint8`` array of shape
        ``(bands, side, side)`` with ``side = 2 * (self.imgsz // 2)``.
        """
        relative_path = str(self.df_metadata.sample(n=1).path.iloc[0])
        url = self._BASE_URL + relative_path.lstrip("/")
        response = requests.get(
            url,
            headers={"User-Agent": "Mozilla/5.0"},
            timeout=self._REQUEST_TIMEOUT,
        )
        # Fail early on 4xx/5xx instead of handing an error page to h5py.
        response.raise_for_status()

        side = 2 * (self.imgsz // 2)
        with h5py.File(io.BytesIO(response.content), "r") as f:
            raster = f["/hr/data"]
            _, height, width = raster.shape
            top = self._random_start(height, side)
            left = self._random_start(width, side)
            # Read only the selected window rather than the full raster.
            patch = raster[:, top:top + side, left:left + side]
        return patch.astype("uint8")

    def _store_result(self, result):
        """
        Append ``result`` to the store, or overwrite a uniformly random slot
        among the first ``queue_size`` entries once the store is full.
        """
        if len(self.store) >= self.queue_size:
            # Exclusive upper bound: every slot 0 .. queue_size - 1 is eligible.
            index = int(np.random.randint(0, self.queue_size))
            self.store[index] = result
        else:
            self.store.append(result)

    def fetch_and_store(self):
        """
        Worker loop: repeatedly fetch ``batch_size`` random crops, augment
        them via `transform`, and push the result into the store.

        Runs until interrupted with ``KeyboardInterrupt``. Any other exception
        is printed and the current batch is dropped, so transient download or
        decode failures do not stop the worker.
        """
        # Re-seed from OS entropy so each worker draws a different random stream.
        np.random.seed(int.from_bytes(os.urandom(4), 'little'))
        while True:
            try:
                batch = [self._fetch_random_patch() for _ in range(self.batch_size)]
                self._store_result(DinoDataset.transform(batch))

                # Refresh the "last updated" marker on roughly 20% of batches.
                if np.random.random() < 0.20:
                    write_last_updated()
            except KeyboardInterrupt:
                break
            except Exception as e:
                print("ERROR:", e)
                continue



if __name__ == "__main__":
    # Manual smoke test: start the loader and report, every few seconds,
    # how many batches are currently held in the store.
    poll_interval_s = 5
    dataset = DinoDataset(queue_size=1000, batch_size=3, imgsz=1696)
    queue = dataset.store
    while True:
        print(len(queue))
        time.sleep(poll_interval_s)