|
|
import os |
|
|
import io |
|
|
import time |
|
|
import torch |
|
|
import random |
|
|
import requests |
|
|
import numpy as np |
|
|
import geopandas as gpd |
|
|
import h5py |
|
|
import xarray as xr |
|
|
from torch import nn |
|
|
from torch.utils.data import Dataset |
|
|
import albumentations as A |
|
|
from torchvision.transforms import v2 |
|
|
import fsspec |
|
|
from PIL import Image |
|
|
|
|
|
from utils import ( |
|
|
shared_store, process_pool, write_last_updated, |
|
|
AddPoissonNoise, AddSaltPepperNoise |
|
|
) |
|
|
|
|
|
|
|
|
class DinoDataset(Dataset):
    """
    🧠 DinoDataset — resolution-agnostic loader for Core-Five 🌍

    Streams random crops of HR satellite imagery from the Hugging Face
    "core-five" dataset, builds 8 augmented views per patch
    (DINO-style multi-view training) and pushes the resulting batches
    into a shared queue (`utils.shared_store`) consumed elsewhere.

    ---
    👤 Author: Gajesh Ladhar
    🔗 LinkedIn: https://www.linkedin.com/in/gajeshladhar/
    🤗 Hugging Face: https://huggingface.co/gajeshladhar
    """

    # Augmentation pipeline shared by every call to `transform`.
    # Built ONCE at class-creation time: the original rebuilt the whole
    # v2.Compose on every batch, which is pure per-call overhead.
    AUGMENT = v2.Compose(
        [
            v2.RandomResizedCrop(256, scale=(0.08, 1.0)),
            v2.RandomApply([v2.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
            v2.RandomGrayscale(p=0.2),
            v2.RandomApply([v2.GaussianBlur(kernel_size=7, sigma=(0.1, 2.0))]),
            v2.RandomApply([v2.RandomSolarize(threshold=128)], p=0.2),
            v2.RandomHorizontalFlip(),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    def __init__(self, imgsz, batch_size=1, queue_size=50):
        """
        🚀 Init the dataset with remote Core-Five metadata and start
        async patch fetching on the shared process pool.

        Args:
            imgsz (int): Patch size in pixels (min 320 enforced).
            batch_size (int): Number of patches per batch.
            queue_size (int): Max queue length for the shared store.

        Raises:
            ValueError: If ``imgsz`` is below 320.
        """
        if imgsz < 320:
            raise ValueError("imgsz must be ≥ 320 for stable patch extraction — got {}".format(imgsz))
        self.imgsz = imgsz
        metadata_url = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/metadata.parquet"
        # Metadata parquet holds one row per scene with a relative `path` column.
        self.df_metadata = gpd.read_parquet(fsspec.open(metadata_url).open())
        self.batch_size = batch_size
        self.queue_size = queue_size
        self.store = shared_store

        # Six concurrent producer workers keep the queue filled.
        for _ in range(6):
            process_pool.submit(self.fetch_and_store)

    @staticmethod
    def transform(batch):
        """
        🎛️ Build 8 augmented views per input image (DINO multi-view).

        Args:
            batch (list[np.ndarray]): uint8 arrays in CHW order
                (assumes 3 bands so PIL can ingest them — TODO confirm
                against Core-Five HR band count).

        Returns:
            dict: ``{"views": float32 tensor of shape (B, 8, 3, 256, 256)}``,
            ImageNet-normalized. (The original docstring claimed separate
            'student'/'teacher' uint8 tensors; the code never produced them.)
        """
        views = []
        for img in batch:
            # PIL expects HWC; source arrays are CHW.
            pil_img = Image.fromarray(img.transpose(1, 2, 0))
            # AUGMENT already ends in ToImage/ToDtype, so each result IS a
            # tensor — the original wrapped it in torch.tensor(), causing a
            # redundant copy and a UserWarning.
            augmented = [DinoDataset.AUGMENT(pil_img) for _ in range(8)]
            views.append(torch.stack(augmented))

        return {
            "views": torch.stack(views),
        }

    def fetch_and_store(self):
        """
        🔄 Producer loop: continuously sample random HR crops from
        Core-Five, augment them via `transform`, and insert the result
        into the shared queue (replacing a random slot once full).

        Runs forever in a pool worker; per-batch errors are printed and
        the loop continues (deliberate best-effort behavior).
        """
        # Per-worker reseed: without this, forked workers would inherit an
        # identical numpy RNG state and fetch identical patches.
        np.random.seed(int.from_bytes(os.urandom(4), 'little'))
        base_url = "https://huggingface.co/datasets/gajeshladhar/core-five/resolve/main/"
        while True:
            try:
                batch = []
                for _ in range(self.batch_size):
                    path = os.path.join(base_url,
                                        self.df_metadata.sample(n=1).path.iloc[0])
                    # Timeout added: without one, a stalled download blocks
                    # this worker forever (requests has no default timeout).
                    response = requests.get(path, headers={"User-Agent": "Mozilla/5.0"}, timeout=120)
                    buffer = io.BytesIO(response.content)
                    with h5py.File(buffer, "r") as f:
                        x = f["hr/x"][:]
                        y = f["hr/y"][:]
                        data = f["/hr/data"][:]
                        bands = list(range(data.shape[0]))

                    ds = xr.DataArray(data, dims=['band', 'y', 'x'], coords=[bands, y, x]).astype("uint8")

                    # Pick a random center such that the full imgsz window fits.
                    imgsz_half = self.imgsz // 2
                    yid = np.random.randint(imgsz_half, len(ds.y) - imgsz_half)
                    xid = np.random.randint(imgsz_half, len(ds.x) - imgsz_half)
                    ds = ds.isel(y=range(yid - imgsz_half, yid + imgsz_half),
                                 x=range(xid - imgsz_half, xid + imgsz_half)).compute()
                    # Re-grid crop coordinates onto an even spacing.
                    ds['y'], ds['x'] = np.linspace(ds.y.values[0], ds.y.values[-1], ds.shape[1]), \
                                       np.linspace(ds.x.values[0], ds.x.values[-1], ds.shape[2])

                    batch.append(ds.data)

                result = DinoDataset.transform(batch)
                if len(self.store) >= self.queue_size:
                    # Off-by-one fix: np.random.randint's upper bound is
                    # EXCLUSIVE, so the original (0, queue_size - 1) could
                    # never overwrite the last slot of the queue.
                    index = np.random.randint(0, self.queue_size)
                    self.store[index] = result
                else:
                    self.store.append(result)

                # ~1 in 5 batches: touch the heartbeat file for monitoring.
                if np.random.random() < 0.20:
                    write_last_updated()
            except KeyboardInterrupt:
                break
            except Exception as e:
                # Best-effort: log and keep producing; do not kill the worker.
                print("ERROR:", e)
                continue
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: start the producer workers, then report how full the
    # shared queue is every five seconds.
    monitor = DinoDataset(imgsz=1696, batch_size=3, queue_size=1000)
    while True:
        print(len(monitor.store))
        time.sleep(5)