# Spaces: Sleeping
# NOTE(review): the line above is a Hugging Face Spaces status banner that was
# captured when this file was exported — it is not part of the program.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import os
import numpy as np
from tqdm import tqdm

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature Extractor (ResNet)
# Pretrained ResNet-50 with its final FC layer stripped, leaving the global
# average pool so it emits one feature vector per image.
# NOTE(review): `resnet` is never referenced elsewhere in this chunk, yet
# loading it downloads the pretrained weights at import time — confirm it is
# actually used before keeping it at module level.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet.eval().to(device)

# Standard ImageNet preprocessing: resize, convert to tensor, then normalize
# per channel with the ImageNet mean/std.
# NOTE(review): Normalize produces values well outside [0, 1]; any model that
# reconstructs these tensors must not assume [0, 1]-bounded targets — verify
# downstream consumers.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class ImageDataset(Dataset):
    """Flat-directory image dataset for unlabeled training images.

    Lists every .png/.jpg/.jpeg file directly inside `img_dir` (non-recursive),
    loads each as RGB, and applies a preprocessing transform.

    Args:
        img_dir: Directory containing the image files.
        img_transform: Optional callable applied to each loaded PIL image.
            Defaults to the module-level `transform` (ImageNet preprocessing),
            preserving the original behavior.

    Raises:
        ValueError: If `img_dir` contains no supported image files.
    """

    # Recognized file extensions (checked case-insensitively).
    _EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_dir, img_transform=None):
        self.img_dir = img_dir
        # Fall back to the module-level ImageNet transform when none is given.
        self.img_transform = img_transform if img_transform is not None else transform
        # Sort for a deterministic ordering — os.listdir order is arbitrary
        # and filesystem-dependent.
        self.img_names = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self._EXTENSIONS)
        )
        if len(self.img_names) == 0:
            raise ValueError(f"No images found in {img_dir}. Please upload images first!")

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_names[idx])
        img = Image.open(img_path).convert("RGB")
        return self.img_transform(img)
# Autoencoder Model
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    The encoder halves the spatial resolution three times
    (3 -> 32 -> 64 -> 128 channels); the decoder mirrors it with transposed
    convolutions. Input height and width must be divisible by 8 for the
    output shape to match the input shape exactly.

    Bug fix: the decoder originally ended in a Sigmoid, constraining outputs
    to [0, 1], while the training targets produced by the module-level
    `transform` are ImageNet-normalized tensors whose values fall far outside
    that range — the model could never reconstruct them and the MSE loss
    would plateau. The final layer is now linear so the output range matches
    the normalized targets.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            # Linear output: targets are normalized, not bounded to [0, 1].
            nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1, output_padding=1),
        )

    def forward(self, x):
        """Encode then decode `x`; returns a tensor with the same shape as `x`."""
        return self.decoder(self.encoder(x))
# Training
def train_autoencoder(data_dir="data/unlabeled", epochs=10, lr=1e-3,
                      batch_size=16, save_path="models/autoencoder.pth"):
    """Train a ConvAutoencoder on the images in `data_dir` and save its weights.

    Args:
        data_dir: Directory of training images (see ImageDataset).
        epochs: Number of full passes over the dataset.
        lr: Adam learning rate.
        batch_size: Mini-batch size for the DataLoader.
        save_path: File path where the trained state_dict is written.

    Returns:
        The trained ConvAutoencoder, or None when no images were found.
    """
    try:
        dataset = ImageDataset(data_dir)
    except ValueError as e:
        # Best-effort for notebook use: report the problem instead of crashing.
        print(e)
        print("Tip: Upload images to data/unlabeled/ using the Colab file uploader or download sample images.")
        return None
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model = ConvAutoencoder().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    print(f"Training autoencoder on {len(dataset)} images...")
    model.train()  # hoisted: mode never changes during the loop
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            batch = batch.to(device)
            recon = model(batch)
            # Reconstruction objective: the input is its own target.
            loss = criterion(recon, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")
    # Persist weights; create the parent directory on first run.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"Autoencoder saved to {save_path}")
    return model
# Script entry point: train with the default settings when run directly.
if __name__ == "__main__":
    train_autoencoder()