AlphaGPT-Image (DCGAN Faces)

Это собственная реализация DCGAN, обученная на датасете CelebA. Модель генерирует лица людей в разрешении 64x64 пикселя.

Характеристики:

  • Стиль: Old AI / Surrealism (характерная размытость и артефакты ранних нейросетей).
  • Архитектура: DCGAN (Deep Convolutional Generative Adversarial Network).
  • Обучение: 20 эпох на 200,000+ изображений CelebA через Kaggle GPU.
  • Параметры: $nz=100$, $ngf=64$, $ndf=64$.

Как использовать (Python / PyTorch)

Для запуска тебе понадобятся библиотеки torch, torchvision, matplotlib и huggingface_hub.

import torch
import torch.nn as nn
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

# 1. Архитектура Генератора
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8), nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4), nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2), nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, input): return self.main(input)

# 2. Загрузка весов
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
repo_id = "prostochel097/alphagpt-image" # Твой ID репозитория

weights_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
model = Generator().to(device)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()

# 3. Генерация
noise = torch.randn(1, 100, 1, 1, device=device)
with torch.no_grad():
    fake = model(noise).detach().cpu()

plt.figure(figsize=(5,5))
plt.axis("off")
plt.imshow(vutils.make_grid(fake, padding=2, normalize=True).permute(1,2,0))
plt.show()
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support