Spaces:
Build error
Build error
Included project files
Browse files- S0_PrepareDataset.py +109 -0
- S2_TimberDataset.py +78 -0
- S3_intermediateDataset.py +85 -0
- S4_Training.py +89 -0
- S5_Evaluation.ipynb +209 -0
- requirements.txt +3 -0
S0_PrepareDataset.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Download files
|
| 2 |
+
from zipfile import ZipFile
|
| 3 |
+
import tarfile
|
| 4 |
+
import requests
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import os
|
| 7 |
+
from joblib import Parallel, delayed
|
| 8 |
+
|
| 9 |
+
def listdir_full(path: str) -> list[str]:
    """Return every entry of *path* prefixed with the directory, '/'-joined."""
    entries = os.listdir(path)
    return ["/".join((path, entry)) for entry in entries]
| 11 |
+
|
| 12 |
+
def download_file(url, pos):
    """Stream *url* into ``data/<basename>`` with a per-file tqdm bar.

    Skips the download if the target file already exists; returns the
    local file path either way.

    Args:
        url: HTTP(S) URL of the archive to fetch.
        pos: tqdm row position so parallel workers get separate bars.
    """
    local_filename = f"data/{url.split('/')[-1]}"
    # Fix: the original assumed "data/" already existed and crashed with
    # FileNotFoundError on a fresh checkout.
    os.makedirs("data", exist_ok=True)
    if not os.path.exists(local_filename):
        with requests.get(url, stream=True) as r:
            r.raise_for_status()

            total_size = int(r.headers.get("content-length", 0))
            block_size = 8192
            with tqdm(total=total_size, unit="B", unit_scale=True, position=pos,
                      desc=f"#{pos} {local_filename}", ncols=100) as p_bar:
                with open(local_filename, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=block_size):
                        p_bar.update(len(chunk))
                        f.write(chunk)
    return local_filename
| 26 |
+
|
| 27 |
+
# Download and extract the macroscopic timber archives on first run only.
IMAGE_DIR = "data/image"
if not os.path.isdir(IMAGE_DIR):
    links = ["http://www.inf.ufpr.br/lesoliveira/download/macroscopic0.zip",
             "http://www.inf.ufpr.br/lesoliveira/download/macroscopic1.tar.gz",
             "http://www.inf.ufpr.br/lesoliveira/download/macroscopic2.tar.gz",
             "http://www.inf.ufpr.br/lesoliveira/download/macroscopic3.tar.gz",
             "http://www.inf.ufpr.br/lesoliveira/download/macroscopic4.tar.gz",
             ]
    archives: list[str] = Parallel(n_jobs=-1)(delayed(download_file)(link, i) for i, link in enumerate(links))

    def unzip(file: str):
        """Extract one downloaded archive (.zip or .tar.gz) into IMAGE_DIR."""
        if file.endswith(".zip"):
            with ZipFile(file) as zip_file:
                zip_file.extractall(IMAGE_DIR)
        if file.endswith(".tar.gz"):
            # SECURITY: extractall on an untrusted tarball can escape
            # IMAGE_DIR (path traversal). Accepted here because the source
            # is a fixed university mirror; pass filter="data" once the
            # project requires Python >= 3.12.
            with tarfile.open(file, "r:gz") as tar:
                tar.extractall(IMAGE_DIR)

    Parallel(n_jobs=-1)(delayed(unzip)(file) for file in archives)
    # Delete macOS AppleDouble sidecar files ("._*") -- they are not images.
    # (Was a side-effect list comprehension; a plain loop states the intent.)
    for entry in os.listdir(IMAGE_DIR):
        if entry.startswith("._"):
            os.remove(f"{IMAGE_DIR}/{entry}")
| 45 |
+
|
| 46 |
+
from pathlib import Path
import shutil

images = os.listdir(IMAGE_DIR)
# Group by species: index+1 in this list matches the 2-digit prefix
# (01..41) of the extracted image file names.
labels = ["Aspidosperma polyneuron", "Araucaria angustifolia", "Tabebuia sp.", "Cordia goeldiana", "Cordia sp.", "Hura crepitans", "Acrocarpus fraxinifolius", "Hymenaea sp.", "Peltogyne sp.", "Hymenolobium petraeum", "Myroxylon balsamum", "Dipteryx sp.", "Machaerium sp.", "Bowdichia sp.", "Mimosa scabrella", "Cedrelinga catenaeformis", "Goupia glabra", "Ocotea porosa", "Mezilaurus itauba", "Laurus nobilis", "Bertholethia excelsa", "Cariniana estrellensis", "Couratari sp.", "Carapa guianensis", "Cedrela fissilis", "Melia azedarach", "Swietenia macrophylla", "Brosimum paraense", "Bagassa guianensis", "Virola surinamensis", "Eucalyptus sp.", "Pinus sp.", "Podocarpus lambertii", "Grevilea robusta", "Balfourodendron riedelianum", "Euxylophora paraensis", "Micropholis venulosa", "Pouteria pachycarpa", "Manilkara huberi", "Erisma uncinatum", "Vochysia sp."]

def group_label(index: int, label: str):
    """Move every image whose name starts with the species' 2-digit id
    into IMAGE_DIR/<label>/.

    Fixes: the original named the parameter `id` (shadowing the builtin,
    then rebinding it to a str) and used a list comprehension purely for
    its side effects.
    """
    prefix = f"{index + 1:02d}"
    class_dir = f"{IMAGE_DIR}/{label}"
    Path(class_dir).mkdir(parents=True, exist_ok=True)

    for im in images:
        if im.startswith(prefix):
            shutil.move(f"{IMAGE_DIR}/{im}", f"{class_dir}/{im}")

Parallel(n_jobs=-1)(delayed(group_label)(i, l) for i, l in enumerate(labels))
| 61 |
+
|
| 62 |
+
# Train, Test Split (per class, train_ratio of each class goes to train).
train_dir = f"{IMAGE_DIR}/train"
test_dir = f"{IMAGE_DIR}/test"
test_full_dir = f"{IMAGE_DIR}/test_full"

Path(train_dir).mkdir(parents=True, exist_ok=True)
Path(test_dir).mkdir(parents=True, exist_ok=True)
Path(test_full_dir).mkdir(parents=True, exist_ok=True)

train_ratio = 0.9
dirs = os.listdir(IMAGE_DIR)

def train_test_split(dir: str):
    """Split one class directory into train/<dir> and test/<dir>.

    test_full/<dir> keeps an un-patched copy of each test image so the
    full-size originals survive the patching step below.
    """
    # Fix: sort for a reproducible split -- os.listdir order is arbitrary,
    # so the original produced a different train/test partition per run/OS.
    imgs = sorted(os.listdir(f"{IMAGE_DIR}/{dir}"))
    split_index = int(len(imgs) * train_ratio)
    train, test = imgs[:split_index], imgs[split_index:]
    Path(f"{train_dir}/{dir}").mkdir(parents=True, exist_ok=True)
    Path(f"{test_dir}/{dir}").mkdir(parents=True, exist_ok=True)
    Path(f"{test_full_dir}/{dir}").mkdir(parents=True, exist_ok=True)

    # Plain loops instead of side-effect list comprehensions.
    for t in train:
        shutil.move(f"{IMAGE_DIR}/{dir}/{t}", f"{train_dir}/{dir}/{t}")
    for t in test:
        shutil.copy(f"{IMAGE_DIR}/{dir}/{t}", f"{test_full_dir}/{dir}/{t}")
        shutil.move(f"{IMAGE_DIR}/{dir}/{t}", f"{test_dir}/{dir}/{t}")
    shutil.rmtree(f"{IMAGE_DIR}/{dir}")

for d in dirs:
    if d not in ("test", "train", "test_full"):
        train_test_split(d)
| 88 |
+
|
| 89 |
+
# Split every train/test image into non-overlapping L x L patches.
import cv2
import patchify

L = 816  # patch side length in pixels

def patch_dir(dir):
    """Replace each image in *dir* with its L x L x 3 non-overlapping patches."""
    for name in os.listdir(dir):
        path = f"{dir}/{name}"
        img = cv2.imread(path)
        # Grid of non-overlapping patches: shape (w, h, 1, L, L, 3), step == L.
        patches = patchify.patchify(img, (L, L, 3), L)
        w, h, _ = patches.shape[:3]
        patches = patches.reshape(w * h, *patches.shape[3:])
        stem, ext = os.path.splitext(path)
        for i, p in enumerate(patches):
            cv2.imwrite(f"{stem}_{i}{ext}", p)
        # Fix: delete the source only after its patches are written. The
        # original removed it before patchify ran, so a failed read/patch
        # (e.g. corrupt file) permanently lost the image.
        os.remove(path)

Parallel(n_jobs=-1)(delayed(patch_dir)(d) for d in listdir_full(train_dir))
Parallel(n_jobs=-1)(delayed(patch_dir)(d) for d in listdir_full(test_dir))
S2_TimberDataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
import cv2
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
import os
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import pandas as pd
|
| 9 |
+
# 3264 x 2448
|
| 10 |
+
|
| 11 |
+
DATA_DIR = "data/image/train"
|
| 12 |
+
labels = ["Aspidosperma polyneuron", "Araucaria angustifolia", "Tabebuia sp.", "Cordia goeldiana", "Cordia sp.", "Hura crepitans", "Acrocarpus fraxinifolius", "Hymenaea sp.", "Peltogyne sp.", "Hymenolobium petraeum", "Myroxylon balsamum", "Dipteryx sp.", "Machaerium sp.", "Bowdichia sp.", "Mimosa scabrella", "Cedrelinga catenaeformis", "Goupia glabra", "Ocotea porosa", "Mezilaurus itauba", "Laurus nobilis", "Bertholethia excelsa", "Cariniana estrellensis", "Couratari sp.", "Carapa guianensis", "Cedrela fissilis", "Melia azedarach", "Swietenia macrophylla", "Brosimum paraense", "Bagassa guianensis", "Virola surinamensis", "Eucalyptus sp.", "Pinus sp.", "Podocarpus lambertii", "Grevilea robusta", "Balfourodendron riedelianum", "Euxylophora paraensis", "Micropholis venulosa", "Pouteria pachycarpa", "Manilkara huberi", "Erisma uncinatum", "Vochysia sp."]
|
| 13 |
+
label2id = {label:id for id, label in enumerate(labels)}
|
| 14 |
+
|
| 15 |
+
def compile_image_df(data_dir: str, split_at=0.9) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Scan ``data_dir/<species>/<image>`` into (train, val) DataFrames.

    Each row is (Image_ID, Species) where Image_ID is the file path and
    Species the class directory name. The first *split_at* fraction of
    each class goes to train, the remainder to val.

    Fixes: the return annotation claimed a single DataFrame although a
    2-tuple was always returned, and pd.concat ran inside the loop
    (quadratic, and deprecated when seeded with an empty frame).
    """
    columns = ['Image_ID', 'Species']
    train_rows: list[tuple[str, str]] = []
    val_rows: list[tuple[str, str]] = []
    for dir in os.listdir(data_dir):
        imgs = [(f"{data_dir}/{dir}/{img}", dir) for img in os.listdir(f"{data_dir}/{dir}")]
        train_count = int(len(imgs) * split_at)
        train_rows.extend(imgs[:train_count])
        val_rows.extend(imgs[train_count:])

    # Build each frame once at the end.
    return pd.DataFrame(train_rows, columns=columns), pd.DataFrame(val_rows, columns=columns)
| 28 |
+
|
| 29 |
+
class TimberDataset(Dataset):
    """Image-classification dataset over a (Image_ID, Species) DataFrame.

    Loads images lazily from disk with OpenCV and maps the species string
    to an integer id via the module-level ``label2id`` table.
    """

    def __init__(self,
                 dataframe: pd.DataFrame,
                 is_train=False,
                 transform=None,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> None:
        # NOTE(review): the `device` default is evaluated once at class
        # definition time, not per instance -- fine here, but worth knowing.
        super().__init__()
        self.dataframe = dataframe   # rows: (image path, species name)
        self.is_train = is_train     # not read anywhere in this class
        self.transform = transform   # optional torchvision transform per item
        self.device = device

    def __len__(self) -> int:
        return len(self.dataframe)

    def __getitem__(self, idx: list[int]|Tensor):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.dataframe.iloc[idx,0])
        # NOTE(review): cv2.imread yields BGR while Image.fromarray is
        # interpreted as RGB, so channels are swapped. It is at least
        # consistent between train and eval here -- confirm before reusing
        # trained weights with RGB loaders elsewhere.
        image = cv2.imread(img_name)
        image = Image.fromarray(image)

        label = self.dataframe.iloc[idx,1]
        label = label2id[label]
        label = torch.tensor(int(label))

        if self.transform:
            image = self.transform(image)
        return image.to(self.device), label.to(self.device)
| 59 |
+
|
| 60 |
+
def build_dataloader(
    train_ratio = 0.9,
    img_size = (640,640),
    batch_size = 12,
) -> tuple[DataLoader,DataLoader]:
    """Create (train, val) DataLoaders over DATA_DIR.

    Images are resized to *img_size* and converted to tensors; only the
    training loader shuffles.
    """
    frames = compile_image_df(DATA_DIR, split_at=train_ratio)

    to_tensor = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])

    loaders = []
    for frame, shuffle in zip(frames, (True, False)):
        dataset = TimberDataset(frame, is_train=True, transform=to_tensor)
        loaders.append(DataLoader(dataset, shuffle=shuffle, batch_size=batch_size))

    return loaders[0], loaders[1]
S3_intermediateDataset.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from time import time
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from S2_TimberDataset import build_dataloader
|
| 8 |
+
from typing import Callable
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
# 3264 x 2448
|
| 12 |
+
|
| 13 |
+
def write_random_lowercase(n):
    """Return a random string of *n* lowercase ASCII letters (os.urandom-based)."""
    min_lc = ord(b'a')
    len_lc = 26
    raw = os.urandom(n)
    # Map each byte 0..255 onto 'a'..'z' (modulo; slight bias is fine for
    # file-name suffixes).
    chars = bytes(min_lc + b % len_lc for b in raw)
    return chars.decode("utf-8")
| 20 |
+
|
| 21 |
+
INTERMEDIATE_DIR = "data/intermediate"
|
| 22 |
+
|
| 23 |
+
class IntermediateDataset(Dataset):
    """Dataset of pre-computed batches cached under INTERMEDIATE_DIR/<name>.

    Each cached batch is a pair of files sharing a random stem:
    ``<stem>.pt`` (the batch tensor) and ``<stem>.txt`` (labels as
    '-'-joined integers). One __getitem__ returns one whole cached batch.
    """

    def __init__(self,
                 name,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> None:
        super().__init__()
        self.name = name
        self.device = device
        files = os.listdir(f"{INTERMEDIATE_DIR}/{self.name}")
        # NOTE(review): tensors[i] and labels[i] are paired purely by
        # position. Both arrays are filtered from the same listdir() result,
        # so stems keep the same relative order -- fragile; sorting both
        # lists would make the pairing explicit. Confirm before relying on it.
        self.tensors = np.array([f for f in files if os.path.splitext(f)[-1] == ".pt"])
        self.labels = np.array([f for f in files if os.path.splitext(f)[-1] == ".txt"])

    def __len__(self) -> int:
        return len(self.tensors)

    def __getitem__(self, idx: list[int]|Tensor):
        tensor = self.tensors[idx]
        labels = self.labels[idx]
        images = torch.load(f"{INTERMEDIATE_DIR}/{self.name}/{tensor}")

        with open(f"{INTERMEDIATE_DIR}/{self.name}/{labels}", 'r') as f:
            labels = f.readline().split("-")
            labels = Tensor(list(map(int, labels)))

        images = images.to(self.device)
        # CrossEntropyLoss needs int64 class indices.
        labels = labels.to(device=self.device, dtype=torch.int64)

        return images, labels

    @staticmethod
    def prepare_intermediate_dataset(pred: Callable, name: str, dataset: DataLoader, iterations = 1) -> None:
        """Run *pred* over *dataset* and cache every output batch to disk.

        Writes one .pt/.txt pair per batch into INTERMEDIATE_DIR/<name>;
        the stem combines a timestamp with a random suffix to avoid
        collisions between batches written in the same second.
        """
        with torch.no_grad():
            for _ in range(iterations):
                for images, labels in tqdm(dataset):
                    out = pred(images)
                    labels = np.char.mod('%d', labels.cpu().numpy())
                    labels = '-'.join(labels)

                    file_name = f"{INTERMEDIATE_DIR}/{name}/{int(time())}_{write_random_lowercase(10)}"
                    torch.save(out,f"{file_name}.pt")
                    with open(f"{file_name}.txt", 'w') as f:
                        f.write(labels)
| 64 |
+
|
| 65 |
+
def build_intermediate_dataset_if_not_exists(pred_: Callable, name: str, dataset: DataLoader) -> None:
    """Populate INTERMEDIATE_DIR/<name> from *dataset* unless it already has files.

    Fix: the original wrapped os.mkdir in bare ``except: pass``, which also
    swallowed real failures (permissions, a file blocking the name, even
    KeyboardInterrupt). makedirs with exist_ok=True tolerates only the
    directory already existing.
    """
    os.makedirs(f"{INTERMEDIATE_DIR}/{name}", exist_ok=True)

    if not os.listdir(f"{INTERMEDIATE_DIR}/{name}"):
        IntermediateDataset.prepare_intermediate_dataset(pred_, name, dataset)
| 73 |
+
|
| 74 |
+
def intermediate_dataset(name:str) -> DataLoader:
    """Wrap the cached batches under INTERMEDIATE_DIR/<name> in a DataLoader."""
    # batch_size=1 because every cached item is already a full batch.
    return DataLoader(IntermediateDataset(name=name),batch_size=1)
| 76 |
+
|
| 77 |
+
if __name__ == '__main__':
    # Smoke test: cache one tiny pass of batches and read one back.
    # Fix: build_dataloader returns (train, val) -- the original unpacked
    # THREE values and raised ValueError before doing any work. The stray
    # "a" expression and the (i1,i2,i3) batch unpack were also dead/broken
    # debugging residue.
    train, val = build_dataloader(train_ratio= 0.01)

    build_intermediate_dataset_if_not_exists(lambda x:x, "testing", train)

    train_loader = DataLoader(IntermediateDataset("testing"), batch_size=1)
    images, labels = next(iter(train_loader))
| 85 |
+
|
S4_Training.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from time import time
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, Tensor
|
| 6 |
+
from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset
|
| 7 |
+
from S2_TimberDataset import build_dataloader
|
| 8 |
+
from S1_CNN_Model import build_model
|
| 9 |
+
|
| 10 |
+
if __name__ == '__main__':
    img_size = (320,320)
    train_loader, val_loader = build_dataloader(
        img_size=img_size,
        batch_size=16,
    )

    # Cache the (identity-transformed) batches so later epochs skip disk
    # decoding and the torchvision transform pipeline.
    build_intermediate_dataset_if_not_exists(lambda x:x, "train", train_loader)
    build_intermediate_dataset_if_not_exists(lambda x:x, "val", val_loader)

    train_loader = intermediate_dataset("train")
    val_loader = intermediate_dataset("val")

    model = build_model(img_size=img_size)
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    accuracies = []  # per-epoch validation accuracy, as floats (percent)

    n_epoch = 40
    timer = time()
    pbar_0 = tqdm(range(n_epoch), position=0, ncols=100)
    pbar_0.set_description(f"epoch 1/{n_epoch}")
    for epoch in pbar_0:
        pbar_1 = tqdm(enumerate(train_loader), total=len(train_loader), position=1, ncols=100, leave=False)
        for i, (images, labels) in pbar_1:
            # Each cached item is a whole batch wrapped by the outer
            # DataLoader (batch_size=1); drop that leading singleton dim.
            images = images.reshape(images.shape[1:])
            labels = labels.reshape(labels.shape[1:])
            # Forward
            out = model.forward(images)
            loss = criterion.forward(out, labels)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                pbar_1.set_description(f" loss = {loss.item():.4f} ({(time() - timer)*1000:.4f} ms)")
                timer = time()

        n_correct = 0
        n_samples = 0
        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, Validating . . .")

        with torch.no_grad():
            with tqdm(bar_format='{desc}{postfix}', position=1, leave=False) as val_desc:
                tally = Counter()
                for images, labels in tqdm(val_loader, position=2, ncols=100, leave=False):
                    images = images.reshape(images.shape[1:])
                    labels = labels.reshape(labels.shape[1:])

                    x = model.forward(images)

                    _, predictions = torch.max(x,1)
                    tally += Counter(predictions.tolist())
                    n_samples += labels.shape[0]
                    n_correct += (predictions == labels).sum().item()
                    tally_desc = ' '.join([f"{n}:{c}" for n,c in tally.most_common()])[:80] + "..."
                    val_desc.set_description(f"{n_correct}/{n_samples} correct")
                    val_desc.set_postfix_str(tally_desc)

        # Fix: accuracy was a pre-formatted *string*, so the checkpoint
        # condition `accuracy > max(accuracies)` compared lexicographically
        # (e.g. "9.50%" > "90.00%") and could skip better models. Keep it
        # numeric; format only for display.
        accuracy = n_correct / n_samples * 100
        pbar_0.set_description(f"epoch {epoch+2}/{n_epoch}, accuracy: {accuracy:.2f}%")

        if len(accuracies) >= 3 and accuracy > max(accuracies):
            torch.save(model,f"model_{epoch}.pt")

        accuracies.append(accuracy)

    torch.save(model,"model.pt")

    print(accuracies)
|
S5_Evaluation.ipynb
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 2,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import torch\n",
|
| 10 |
+
"import os\n",
|
| 11 |
+
"from tqdm import tqdm\n",
|
| 12 |
+
"from S1_CNN_Model import CNN_Model\n",
|
| 13 |
+
"from S2_TimberDataset import TimberDataset, compile_image_df\n",
|
| 14 |
+
"from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"from torch.utils.data import DataLoader\n",
|
| 17 |
+
"from torchvision import transforms"
|
| 18 |
+
]
|
| 19 |
+
},
|
| 20 |
+
{
|
| 21 |
+
"cell_type": "code",
|
| 22 |
+
"execution_count": 2,
|
| 23 |
+
"metadata": {},
|
| 24 |
+
"outputs": [],
|
| 25 |
+
"source": [
|
| 26 |
+
"test_df, _ = compile_image_df(\"data/image/test\", split_at=1.0)\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"def listdir_full(path: str) -> list[str]:\n",
|
| 29 |
+
" return [f\"{path}/{p}\" for p in os.listdir(path)]\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"transform = transforms.Compose([\n",
|
| 32 |
+
" transforms.Resize((320,320)),\n",
|
| 33 |
+
" transforms.ToTensor(),\n",
|
| 34 |
+
"])\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"test_loader = DataLoader(TimberDataset(test_df, is_train=True,transform=transform),\n",
|
| 37 |
+
" shuffle=True,\n",
|
| 38 |
+
" batch_size=12)\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"build_intermediate_dataset_if_not_exists(lambda x:x, \"test\", test_loader)\n",
|
| 41 |
+
"test_loader = intermediate_dataset(\"test\") "
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": 3,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [
|
| 49 |
+
{
|
| 50 |
+
"name": "stderr",
|
| 51 |
+
"output_type": "stream",
|
| 52 |
+
"text": [
|
| 53 |
+
"100%|ββββββββββ| 17/17 [09:00<00:00, 31.82s/it]"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"name": "stdout",
|
| 58 |
+
"output_type": "stream",
|
| 59 |
+
"text": [
|
| 60 |
+
"[('ckpt/model_37.pt', 92.36111111111111), ('ckpt/model_36.pt', 91.39957264957265), ('ckpt/model.pt', 90.11752136752136), ('ckpt/model_23.pt', 90.11752136752136), ('ckpt/model_27.pt', 90.03739316239316), ('ckpt/model_19.pt', 89.95726495726495), ('ckpt/model_17.pt', 88.56837606837607), ('ckpt/model_16.pt', 87.79380341880342), ('ckpt/model_15.pt', 85.79059829059828), ('ckpt/model_13.pt', 85.04273504273505), ('ckpt/model_11.pt', 85.01602564102564), ('ckpt/model_14.pt', 84.58867521367522), ('ckpt/model_7.pt', 79.8076923076923), ('ckpt/model_6.pt', 79.27350427350427), ('ckpt/model_5.pt', 75.58760683760684), ('ckpt/model_4.pt', 70.86004273504274), ('ckpt/model_3.pt', 63.67521367521367)]\n"
|
| 61 |
+
]
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"name": "stderr",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"\n"
|
| 68 |
+
]
|
| 69 |
+
}
|
| 70 |
+
],
|
| 71 |
+
"source": [
|
| 72 |
+
"def evaluate_model(model: CNN_Model):\n",
|
| 73 |
+
" n_correct = 0\n",
|
| 74 |
+
" n_samples = 0\n",
|
| 75 |
+
"\n",
|
| 76 |
+
" with torch.no_grad(): \n",
|
| 77 |
+
" with tqdm(test_loader, position=1, leave=False) as pb:\n",
|
| 78 |
+
" for images, labels in pb:\n",
|
| 79 |
+
" images = images.reshape(images.shape[1:])\n",
|
| 80 |
+
" labels = labels.reshape(labels.shape[1:])\n",
|
| 81 |
+
"\n",
|
| 82 |
+
" x = model.forward(images)\n",
|
| 83 |
+
" _, predictions = torch.max(x,1)\n",
|
| 84 |
+
" \n",
|
| 85 |
+
" n_samples += labels.shape[0]\n",
|
| 86 |
+
" n_correct += (predictions == labels).sum().item()\n",
|
| 87 |
+
"\n",
|
| 88 |
+
" pb.set_description(f\"{n_correct}/{n_samples} correct predictions ({n_correct/n_samples*100 :.2f}%)\")\n",
|
| 89 |
+
"\n",
|
| 90 |
+
" return n_correct/n_samples*100\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"accuracies = [(model_path, evaluate_model(torch.load(model_path))) for model_path in tqdm(listdir_full(\"ckpt\"), position=0)]\n",
|
| 93 |
+
"accuracies.sort(key = lambda x : x[1], reverse=True)\n",
|
| 94 |
+
"accuracies"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": 4,
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"outputs": [],
|
| 102 |
+
"source": [
|
| 103 |
+
"test_loader = DataLoader(TimberDataset(test_df, is_train=True,transform=transform),\n",
|
| 104 |
+
" batch_size=12)\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"build_intermediate_dataset_if_not_exists(lambda x:x, \"test_full\", test_loader)\n",
|
| 107 |
+
"test_loader = intermediate_dataset(\"test_full\") "
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": 6,
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [
|
| 115 |
+
{
|
| 116 |
+
"name": "stderr",
|
| 117 |
+
"output_type": "stream",
|
| 118 |
+
"text": [
|
| 119 |
+
"100%|ββββββββββ| 17/17 [08:32<00:00, 30.14s/it]"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"name": "stdout",
|
| 124 |
+
"output_type": "stream",
|
| 125 |
+
"text": [
|
| 126 |
+
"[('ckpt/model_37.pt', 97.11538461538461), ('ckpt/model_23.pt', 96.15384615384616), ('ckpt/model_36.pt', 96.15384615384616), ('ckpt/model_27.pt', 95.51282051282051), ('ckpt/model_19.pt', 95.1923076923077), ('ckpt/model.pt', 94.23076923076923), ('ckpt/model_17.pt', 94.23076923076923), ('ckpt/model_16.pt', 93.26923076923077), ('ckpt/model_15.pt', 92.94871794871796), ('ckpt/model_11.pt', 91.98717948717949), ('ckpt/model_13.pt', 91.34615384615384), ('ckpt/model_14.pt', 91.02564102564102), ('ckpt/model_7.pt', 88.78205128205127), ('ckpt/model_6.pt', 86.85897435897436), ('ckpt/model_5.pt', 82.05128205128204), ('ckpt/model_4.pt', 78.84615384615384), ('ckpt/model_3.pt', 70.83333333333334)]\n"
|
| 127 |
+
]
|
| 128 |
+
},
|
| 129 |
+
{
|
| 130 |
+
"name": "stderr",
|
| 131 |
+
"output_type": "stream",
|
| 132 |
+
"text": [
|
| 133 |
+
"\n"
|
| 134 |
+
]
|
| 135 |
+
}
|
| 136 |
+
],
|
| 137 |
+
"source": [
|
| 138 |
+
"def full_images_evaluation(model: CNN_Model):\n",
|
| 139 |
+
" model.image_size = (320,320)\n",
|
| 140 |
+
" n_correct = 0\n",
|
| 141 |
+
" n_samples = 0\n",
|
| 142 |
+
"\n",
|
| 143 |
+
" with torch.no_grad(): \n",
|
| 144 |
+
" with tqdm(test_loader, position=1, leave = False) as pb:\n",
|
| 145 |
+
" for images, labels in pb:\n",
|
| 146 |
+
" images = images.reshape(images.shape[1:])\n",
|
| 147 |
+
" labels = labels.reshape(labels.shape[1:])\n",
|
| 148 |
+
"\n",
|
| 149 |
+
" assert torch.all(labels == labels[0]).item()\n",
|
| 150 |
+
" label = labels[0]\n",
|
| 151 |
+
"\n",
|
| 152 |
+
" x = model.forward(images)\n",
|
| 153 |
+
" _, preds = torch.max(x,1)\n",
|
| 154 |
+
" pred = torch.mode(preds,0).values\n",
|
| 155 |
+
" \n",
|
| 156 |
+
" n_samples += 1\n",
|
| 157 |
+
" n_correct += (pred == label).item()\n",
|
| 158 |
+
"\n",
|
| 159 |
+
" pb.set_description(f\"{n_correct}/{n_samples} correct predictions ({n_correct/n_samples*100 :.2f}%)\")\n",
|
| 160 |
+
"\n",
|
| 161 |
+
" return n_correct/n_samples*100\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"full_accuracies = [(model_path, full_images_evaluation(torch.load(model_path)))\n",
|
| 164 |
+
" for model_path in tqdm(listdir_full(\"ckpt\"), position=0)]\n",
|
| 165 |
+
"full_accuracies.sort(key = lambda x : x[1], reverse=True)\n",
|
| 166 |
+
"print(full_accuracies)"
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
{
|
| 170 |
+
"cell_type": "code",
|
| 171 |
+
"execution_count": 9,
|
| 172 |
+
"metadata": {},
|
| 173 |
+
"outputs": [
|
| 174 |
+
{
|
| 175 |
+
"data": {
|
| 176 |
+
"text/plain": [
|
| 177 |
+
"<zip at 0x286ec807c00>"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
"execution_count": 9,
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"output_type": "execute_result"
|
| 183 |
+
}
|
| 184 |
+
],
|
| 185 |
+
"source": []
|
| 186 |
+
}
|
| 187 |
+
],
|
| 188 |
+
"metadata": {
|
| 189 |
+
"kernelspec": {
|
| 190 |
+
"display_name": "fyp",
|
| 191 |
+
"language": "python",
|
| 192 |
+
"name": "python3"
|
| 193 |
+
},
|
| 194 |
+
"language_info": {
|
| 195 |
+
"codemirror_mode": {
|
| 196 |
+
"name": "ipython",
|
| 197 |
+
"version": 3
|
| 198 |
+
},
|
| 199 |
+
"file_extension": ".py",
|
| 200 |
+
"mimetype": "text/x-python",
|
| 201 |
+
"name": "python",
|
| 202 |
+
"nbconvert_exporter": "python",
|
| 203 |
+
"pygments_lexer": "ipython3",
|
| 204 |
+
"version": "3.10.13"
|
| 205 |
+
}
|
| 206 |
+
},
|
| 207 |
+
"nbformat": 4,
|
| 208 |
+
"nbformat_minor": 2
|
| 209 |
+
}
|
requirements.txt
CHANGED
|
@@ -9,3 +9,6 @@ patchify
|
|
| 9 |
pillow
|
| 10 |
torch
|
| 11 |
torchvision
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
pillow
|
| 10 |
torch
|
| 11 |
torchvision
|
| 12 |
+
tqdm
|
| 13 |
+
requests
|
| 14 |
+
joblib
|