# NOTE: removed non-code residue from a web file viewer (status lines,
# "File size", commit hash, and a line-number gutter) that was accidentally
# captured along with the source and would break the Python file.
import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class AuctionDatasetFromJSON(Dataset):
    """PyTorch dataset backed by a JSON index of auction listings.

    Each entry of the JSON list is expected to provide at least
    ``folder_path``, a non-empty ``images`` list, ``platform``, ``title``
    and ``id``; ``label`` defaults to 0 and ``description`` to "" when
    absent.  Images are resolved relative to ``root_dir``.
    """

    def __init__(self, json_path: str, root_dir: str, transform=None, max_samples=None):
        """
        Args:
            json_path: path to the index file, e.g. dataset/dataset.json
            root_dir: directory holding the image folders, e.g. dataset/raw_data
            transform: optional callable applied to the loaded PIL image
            max_samples: optional cap on the number of samples; 0 yields
                an empty dataset
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
        # Compare against None explicitly: the original truthiness test
        # silently ignored max_samples=0 and loaded the full dataset.
        if max_samples is not None:
            self.data = self.data[:max_samples]
        self.root_dir = Path(root_dir)
        self.transform = transform

    def __len__(self):
        """Number of auction entries (after the optional max_samples cap)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample dict: image tensor/PIL, text, and metadata."""
        auction = self.data[idx]
        # Path to the first image of this auction.
        img_path = self.root_dir / auction['folder_path'] / auction['images'][0]
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Błąd wczytywania {img_path}: {e}")
            # Fallback: a black 224x224 placeholder so batching keeps a
            # consistent shape even when an image file is missing/corrupt.
            img = Image.new('RGB', (224, 224), color='black')
        if self.transform:
            img = self.transform(img)
        # Text input: title + description concatenated for a text encoder.
        text = f"{auction.get('title', '')} {auction.get('description', '')}"
        return {
            'image': img,
            'text': text,
            'platform': auction['platform'],
            'title': auction['title'],
            'id': auction['id'],
            'label': torch.tensor(auction.get('label', 0), dtype=torch.long),
            'folder_path': auction['folder_path']
        }
# Image transforms
def get_transforms():
    """Return the standard ImageNet preprocessing pipeline.

    Resizes to 224x224, converts to a tensor, and normalizes with the
    ImageNet channel mean/std (matches common pretrained backbones).
    Defined as a proper function instead of a named lambda (PEP 8 E731).
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
if __name__ == '__main__':
    # Smoke-test the dataset and DataLoader end to end.
    print("Testowanie DataLoadera...")

    ds = AuctionDatasetFromJSON(
        json_path='../dataset/dataset.json',
        root_dir='../dataset/raw_data',
        transform=get_transforms(),
        max_samples=5,
    )
    print(f"✓ Dataset załadowany: {len(ds)} próbek")

    loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=0)

    # Inspect only the first batch, then stop.
    for sample in loader:
        print(f"\nBatch:")
        print(f" - Image shape: {sample['image'].shape}")
        print(f" - Texts: {len(sample['text'])}")
        print(f" - Platforms: {sample['platform']}")
        print(f" - Labels: {sample['label']}")
        print(f" - Example text: {sample['text'][0][:100]}...")
        break