serviceadvisor / training /dataset.py
viswanani's picture
Upload 22 files
1c7bc31 verified
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as T
class CarIssuesDataset(Dataset):
    """Map-style dataset pairing a car image with an issue label, text and metadata.

    Rows are read from a CSV file with pandas. Each row must provide the
    columns ``image_path`` and ``issue_label``. The column named by
    ``text_col`` and the metadata columns ``car_make``, ``car_model``,
    ``car_year`` and ``mileage_km`` are optional.

    Args:
        csv_path: Location of the CSV file, passed straight to
            ``pandas.read_csv``.
        img_root: Directory prepended to every ``image_path`` that does not
            already start with it. May be a ``str`` or an ``os.PathLike``.
        labels: Ordered collection of known issue labels; its order defines
            the positions of the one-hot target.
        transform: Callable applied to the loaded PIL image. When omitted (or
            falsy) the image is resized to 224x224 and converted with
            ``ToTensor``.
        text_col: Name of the CSV column holding the free-text description.
    """

    # Size of the blank white stand-in used when an image cannot be loaded.
    _PLACEHOLDER_SIZE = (224, 224)
    # Optional per-row columns copied into the ``meta`` dict, in this order.
    _META_COLUMNS = ("car_make", "car_model", "car_year", "mileage_km")

    def __init__(self, csv_path, img_root, labels, transform=None, text_col="customer_text"):
        self.df = pd.read_csv(csv_path)
        self.img_root = img_root
        self.labels = labels
        self.transform = transform or T.Compose([T.Resize((224, 224)), T.ToTensor()])
        self.text_col = text_col

    def __len__(self) -> int:
        """Return the number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(x, y, text, meta)`` for the row at position ``idx``.

        Returns:
            x: Result of ``self.transform`` applied to the RGB image (or to a
                white placeholder when the image could not be loaded).
            y: Plain Python list with one 0/1 entry per element of
                ``self.labels``; all zeros when the row's label is not listed.
            text: The row's text as ``str``; ``""`` when the column or the
                value is missing.
            meta: Dict with the keys ``car_make``, ``car_model``, ``car_year``
                and ``mileage_km``; absent columns map to ``""``.
        """
        row = self.df.iloc[idx]
        x = self.transform(self._load_image(row["image_path"]))
        y = self._one_hot(row["issue_label"])
        text = self._clean_text(row.get(self.text_col, ""))
        meta = {column: row.get(column, "") for column in self._META_COLUMNS}
        return x, y, text, meta

    def _resolve_image_path(self, img_path):
        """Return the path to open for ``img_path``, or ``None`` if it is missing."""
        import os

        # An empty CSV cell arrives as NaN; there is nothing to resolve.
        if pd.isna(img_path):
            return None
        path = str(img_path)
        # fspath() lets img_root be a pathlib.Path; str.startswith() would
        # raise TypeError on a non-str prefix.
        root = os.fspath(self.img_root)
        # NOTE: this is a plain string-prefix test, not a path-component test.
        if path.startswith(root):
            return path
        return os.path.join(root, path)

    def _load_image(self, img_path):
        """Load ``img_path`` as an RGB PIL image, or a white placeholder on failure."""
        import warnings

        path = self._resolve_image_path(img_path)
        if path is None:
            reason = "row has no image path"
        else:
            try:
                # The context manager closes the file handle; convert()
                # returns an independent, fully loaded image.
                with Image.open(path) as handle:
                    return handle.convert("RGB")
            except Exception as exc:
                # Deliberately broad: loading is best-effort, and image
                # decoders raise a variety of exception types for bad files.
                reason = f"{type(exc).__name__}: {exc}"
        # Surface the substitution so bad data does not go unnoticed.
        warnings.warn(
            f"Could not load image {path!r} ({reason}); using a blank placeholder.",
            RuntimeWarning,
            stacklevel=2,
        )
        return Image.new("RGB", self._PLACEHOLDER_SIZE, (255, 255, 255))

    def _one_hot(self, label) -> list:
        """Return a 0/1 list marking which entries of ``self.labels`` equal ``label``."""
        return [1 if label == candidate else 0 for candidate in self.labels]

    @staticmethod
    def _clean_text(value) -> str:
        """Return ``value`` as ``str``, mapping missing values (NaN/None) to ``""``."""
        # Without this check an empty CSV cell would become the text "nan".
        if pd.isna(value):
            return ""
        return str(value)