Spaces:
Build error
Build error
| import os | |
| import torch | |
| from time import time | |
| from torch import Tensor | |
| from tqdm import tqdm | |
| from torch.utils.data import Dataset, DataLoader | |
| from S2_TimberDataset import build_dataloader | |
| from typing import Callable | |
| import pandas as pd | |
| import numpy as np | |
| # 3264 x 2448 | |
def write_random_lowercase(n):
    """Return a string made of ``n`` random lowercase ASCII letters.

    The randomness comes from ``os.urandom``; every random byte is folded
    into the code-point range of ``'a'``..``'z'`` with a modulo.
    """
    first_letter = ord(b'a')
    alphabet_size = 26
    # Fold each random byte (0..255) onto one of the 26 code points 97..122.
    letters = bytearray(
        first_letter + raw % alphabet_size for raw in os.urandom(n)
    )
    return letters.decode("utf-8")
# Root folder; each named intermediate dataset lives in "<INTERMEDIATE_DIR>/<name>".
INTERMEDIATE_DIR = "data/intermediate"


class IntermediateDataset(Dataset):
    """Dataset over cached (tensor, labels) file pairs stored on disk.

    Every sample consists of two files in ``INTERMEDIATE_DIR/<name>`` that
    share the same stem: ``<stem>.pt`` (written with ``torch.save``) and
    ``<stem>.txt`` (integer labels joined with ``-`` on a single line).

    Attributes:
        name: Sub-folder of ``INTERMEDIATE_DIR`` this dataset reads from.
        device: Device the loaded tensors and labels are moved to.
        tensors: Array of ``.pt`` file names, sorted by stem.
        labels: Array of ``.txt`` file names, aligned index-by-index with
            ``tensors``.
    """

    def __init__(self,
                 name,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> None:
        """Index the sample files found in ``INTERMEDIATE_DIR/<name>``.

        Args:
            name: Name of the dataset sub-folder.
            device: Target device for the items returned by ``__getitem__``.

        Raises:
            OSError: If the dataset folder cannot be listed.
        """
        super().__init__()
        self.name = name
        self.device = device
        files = os.listdir(f"{INTERMEDIATE_DIR}/{self.name}")
        tensor_stems = {os.path.splitext(f)[0] for f in files
                        if os.path.splitext(f)[-1] == ".pt"}
        label_stems = {os.path.splitext(f)[0] for f in files
                       if os.path.splitext(f)[-1] == ".txt"}
        # Pair files by stem rather than by position in the directory
        # listing: an orphan file (e.g. from an interrupted write) would
        # otherwise shift every following tensor onto the wrong label.
        # Sorting also makes the sample order independent of the filesystem.
        stems = sorted(tensor_stems & label_stems)
        self.tensors = np.array([f"{stem}.pt" for stem in stems])
        self.labels = np.array([f"{stem}.txt" for stem in stems])

    def __len__(self) -> int:
        """Return the number of complete (tensor, labels) pairs."""
        return len(self.tensors)

    def __getitem__(self, idx: int):
        """Load the sample at position ``idx``.

        Args:
            idx: Integer position of the sample. The file path is built from
                a single file name, so only scalar indices are supported.

        Returns:
            A tuple ``(images, labels)``; ``images`` is the object stored in
            the ``.pt`` file moved to ``self.device`` and ``labels`` is a 1-D
            ``int64`` tensor on ``self.device``.

        Raises:
            ValueError: If the label file does not contain ``-``-separated
                integers.
        """
        tensor_file = self.tensors[idx]
        label_file = self.labels[idx]
        # NOTE: torch.load unpickles the file, so only point this dataset at
        # folders produced by trusted code. map_location lets tensors that
        # were saved on a GPU be read on a machine without one.
        images = torch.load(f"{INTERMEDIATE_DIR}/{self.name}/{tensor_file}",
                            map_location=self.device)
        with open(f"{INTERMEDIATE_DIR}/{self.name}/{label_file}", 'r') as f:
            label_tokens = f.readline().split("-")
        # Build the tensor as int64 directly; going through the default
        # float32 ``Tensor`` constructor corrupts integers above 2**24.
        labels = torch.tensor([int(token) for token in label_tokens],
                              dtype=torch.int64)
        images = images.to(self.device)
        labels = labels.to(device=self.device, dtype=torch.int64)
        return images, labels

    @staticmethod
    def prepare_intermediate_dataset(pred: Callable, name: str, dataset: DataLoader, iterations = 1) -> None:
        """Run ``pred`` over ``dataset`` and cache every output on disk.

        For each batch a ``.pt`` file (the output of ``pred``) and a ``.txt``
        file (the batch labels joined with ``-``) are written under
        ``INTERMEDIATE_DIR/<name>`` with a shared, unique stem.

        Args:
            pred: Callable applied to the images of every batch.
            name: Name of the dataset sub-folder to write into.
            dataset: Iterable yielding ``(images, labels)`` batches.
            iterations: Number of full passes over ``dataset``.
        """
        target_dir = f"{INTERMEDIATE_DIR}/{name}"
        # Direct callers may not have created the folder beforehand.
        os.makedirs(target_dir, exist_ok=True)
        with torch.no_grad():
            for _ in range(iterations):
                for images, labels in tqdm(dataset):
                    out = pred(images)
                    # Labels are serialised as "l0-l1-..."; because "-" is
                    # the separator, negative labels cannot be read back.
                    label_text = '-'.join(np.char.mod('%d', labels.cpu().numpy()))
                    # Timestamp plus a random suffix keeps file names unique.
                    file_name = f"{target_dir}/{int(time())}_{write_random_lowercase(10)}"
                    torch.save(out, f"{file_name}.pt")
                    with open(f"{file_name}.txt", 'w') as f:
                        f.write(label_text)
def build_intermediate_dataset_if_not_exists(pred_:Callable, name:str, dataset:DataLoader) -> None:
    """Populate ``INTERMEDIATE_DIR/<name>`` unless it already holds files.

    The folder (including missing parents) is created when needed. The
    dataset is only generated when the folder is empty, so an existing cache
    is never overwritten or extended.

    Args:
        pred_: Callable applied to the images of every batch.
        name: Name of the dataset sub-folder.
        dataset: Iterable yielding ``(images, labels)`` batches.

    Raises:
        OSError: If the folder cannot be created or listed.
    """
    target_dir = f"{INTERMEDIATE_DIR}/{name}"
    # makedirs creates missing parents and tolerates an existing folder,
    # while still surfacing real failures such as missing permissions.
    os.makedirs(target_dir, exist_ok=True)
    if not os.listdir(target_dir):
        IntermediateDataset.prepare_intermediate_dataset(pred_, name, dataset)
def intermediate_dataset(name:str) -> DataLoader:
    """Return a ``DataLoader`` over the cached dataset ``name``, one sample per batch."""
    cached = IntermediateDataset(name=name)
    return DataLoader(cached, batch_size=1)
if __name__ == '__main__':
    # Manual smoke test: cache a small training split unchanged (identity
    # predictor), then read the first cached sample back.
    train_split, val_split, test_split = build_dataloader(train_ratio=0.01)
    build_intermediate_dataset_if_not_exists(lambda x: x, "testing", train_split)

    cached = IntermediateDataset("testing")
    cached_loader = DataLoader(cached, batch_size=1)
    first_batch = next(iter(cached_loader))
    (i1, i2, i3), first_labels = first_batch
    # No-op expression statement; presumably a line to park a debugger
    # breakpoint on.
    "a"