Spaces:
Build error
Build error
| import os | |
| from collections.abc import Iterable | |
| from pathlib import Path | |
| import pickle | |
| import torch | |
| import torch.nn as nn | |
| import matplotlib.pyplot as plt | |
| import fastai.vision.all as fv | |
| from .storage import ImgFeatureStorage | |
| def dummy_loss_func(x, y): | |
| return torch.tensor(0.) | |
| def get_label(file_path): | |
| return os.path.basename(file_path).split('_')[0] | |
| class FeatureExtractorModel(nn.Module): | |
| def __init__(self, original_model): | |
| super().__init__() | |
| self.features = nn.Sequential(*list(original_model.children())[:-1]) | |
| self.pooling = nn.AdaptiveMaxPool2d(1) | |
| def forward(self, x): | |
| x = self.features(x) | |
| x = self.pooling(x) | |
| return x.view(x.size(0), -1) | |
| class FeatureExtractor: | |
| def __init__(self, dataset_path=None, dls=None, item_tfms=None, label_func=get_label, n_epochs=5): | |
| item_tfms = item_tfms or [fv.Resize(224)] | |
| self.dataset_path = dataset_path | |
| self.dls = dls | |
| self.item_tfms = item_tfms | |
| self.label_func = label_func | |
| self.n_epochs = n_epochs | |
| if self.dataset_path and not self.dls: | |
| self.dls = fv.ImageDataLoaders.from_name_func( | |
| self.dataset_path, fv.get_image_files(self.dataset_path), valid_pct=0.2, seed=42, | |
| label_func=self.label_func, item_tfms=self.item_tfms) | |
| def from_dataset(cls, dataset_path, item_tfms=[fv.Resize(224)], label_func=get_label, n_epochs=5): | |
| return cls(dataset_path=dataset_path, item_tfms=item_tfms, label_func=label_func, n_epochs=n_epochs) | |
| def from_learner(cls, extractor): | |
| obj = cls(dls=extractor.dls) | |
| obj.extractor = extractor | |
| return obj | |
| def load(cls, filename, label_func=get_label, item_tfms=[fv.Resize(224)]): | |
| extractor = fv.load_learner(filename, cpu=False) | |
| dls = fv.ImageDataLoaders.from_name_func( | |
| extractor.dls.path, fv.get_image_files(extractor.dls.path), valid_pct=0.2, seed=42, | |
| label_func=label_func, item_tfms=item_tfms) | |
| extractor.dls = dls | |
| return cls.from_learner(extractor) | |
| def export(self, model_name, path=Path('.')): | |
| self.extractor.path = path | |
| self.extractor.export(model_name) | |
| def train(self, n_epochs=None): | |
| n_epochs = n_epochs or self.n_epochs | |
| self.classifier = self.train_classifier(n_epochs) | |
| self.extractor = self.get_extractor() | |
| def train_classifier(self, n_epochs=None): | |
| n_epochs = n_epochs or self.n_epochs | |
| classifier = fv.vision_learner(self.dls, fv.resnet18, metrics=fv.error_rate) | |
| classifier.fine_tune(n_epochs) | |
| return classifier | |
| def get_extractor(self): | |
| model = FeatureExtractorModel(self.classifier.model) | |
| extractor = fv.Learner(self.dls, model, loss_func=dummy_loss_func) | |
| return extractor | |
| def predict(self, input_images): | |
| if not isinstance(input_images, Iterable) or isinstance(input_images, str): | |
| input_images = [input_images] | |
| with self.extractor.no_bar(), self.extractor.no_logging(): | |
| dl = self.extractor.dls.test_dl(input_images) | |
| inp, features, _, dec = self.extractor.get_preds(dl=dl, with_input=True, with_decoded=True) | |
| return features | |
| def predict_for_dataset(self, dls=None): | |
| dls = dls or self.dls | |
| train_features, _ = self.extractor.get_preds(dl=dls.train) | |
| valid_features, _ = self.extractor.get_preds(dl=dls.valid) | |
| all_features = torch.cat([train_features, valid_features]) | |
| all_items = dls.train.items + dls.valid.items | |
| # Create a dictionary mapping image paths to features | |
| features = {image: activation.clone() for image, activation in zip(all_items, all_features)} | |
| return ImgFeatureStorage.from_features_dict(features) |