DSGAN / app.py
Cicikush's picture
Create app.py
1a09219 verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision.transforms import ToPILImage
import numpy as np
# --- Model Tanımı ve Sabitler ---
# Kaggle notebook'unun sonundaki PyTorch Generator modelini buraya kopyalıyoruz.
# Bu, model ağırlıklarını yüklemek için gereklidir.
latent_size = 128 # Gürültü vektörünün boyutu (notebook'taki gibi)
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) # Normalizasyon için kullanılan değerler
# Generator modelinin mimarisi (notebook ile birebir aynı olmalı)
generator = nn.Sequential(
# Giriş: (latent_size) x 1 x 1
nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# Boyut: 512 x 4 x 4
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# Boyut: 256 x 8 x 8
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# Boyut: 128 x 16 x 16
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# Boyut: 64 x 32 x 32
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh()
# Çıkış Boyutu: 3 x 64 x 64
)
# --- Model Yükleme ---
# Hugging Face Spaces'te genellikle CPU kullanılır, bu yüzden cihaza zorluyoruz.
device = torch.device("cpu")
# Modeli CPU'ya taşı
generator.to(device)
# Eğitilmiş ağırlıkları yükle
# Notebook'ta 'G.ckpt' olarak kaydedilmişti. 'generator.pth' olarak yeniden adlandırdığınızı varsayıyoruz.
# Eğer dosya adınız farklıysa, aşağıdaki satırı güncelleyin.
MODEL_PATH = 'G.ckpt'
try:
generator.load_state_dict(torch.load(MODEL_PATH, map_location=device))
print(f"'{MODEL_PATH}' dosyasından model ağırlıkları başarıyla yüklendi.")
except FileNotFoundError:
print(f"UYARI: '{MODEL_PATH}' dosyası bulunamadı. Model rastgele ağırlıklarla çalışacak.")
# Bu durumda uygulama çalışır ama anlamsız resimler üretir.
except Exception as e:
print(f"Model yüklenirken bir hata oluştu: {e}")
# Modeli inference (çıkarım) moduna al. Bu, BatchNorm ve Dropout katmanlarını doğru moda geçirir.
generator.eval()
# --- Görüntü Üretme ve Arayüz Fonksiyonları ---
def denorm(img_tensors):
"""Görüntü tensor'ünü [-1, 1] aralığından [0, 1] aralığına geri dönüştürür."""
return img_tensors * stats[1][0] + stats[0][0]
def generate_image(seed):
"""
Verilen bir seed'e göre rastgele gürültüden bir anime karakteri üretir.
"""
# Tekrarlanabilirlik için seed ayarla
# Gradio'nun state yönetimi nedeniyle bu her zaman mükemmel çalışmayabilir ama denemeye değer.
if seed:
torch.manual_seed(int(seed))
# Rastgele gürültü vektörü oluştur
noise = torch.randn(1, latent_size, 1, 1, device=device)
# Gradient hesaplamalarını devre dışı bırakarak performansı artır
with torch.no_grad():
# Gürültüyü modele vererek sahte bir resim üret
fake_image_tensor = generator(noise)
# Tensor'ü PIL Image formatına dönüştürmek için denormalize et
# Modelin çıktısı [-1, 1] (Tanh) aralığında.
denormalized_tensor = denorm(fake_image_tensor)
# Tensor'ü [0, 255] aralığında bir numpy array'e dönüştür ve PIL Image yap
pil_image = ToPILImage()(denormalized_tensor.squeeze(0)) # batch boyutunu (ilk boyutu) kaldır
return pil_image
# --- Gradio Arayüz Tanımı ---
title = "Anime Karakter Üretici (PyTorch GAN)"
description = """
Bu demo, bir GAN (Generative Adversarial Network) kullanarak rastgele anime karakter portreleri üretir.
Model, Kaggle'daki bir anime karakter veri seti ile eğitilmiştir.
Farklı karakterler görmek için aşağıdaki 'Seed' değerini değiştirin veya 'Rastgele Seed' düğmesine basarak rastgele bir değerle üretin.
"""
article = "<p style='text-align: center;'><a href='https://www.kaggle.com/code/melissamonfared/anime-character-generation-dsgan-gan' target='_blank'>Eğitim için kullanılan Kaggle Notebook'u</a></p>"
iface = gr.Interface(
fn=generate_image,
inputs=gr.Number(label="Seed (Rastgelelik için başlangıç değeri)", value=1234),
outputs=gr.Image(type="pil", label="Üretilen Karakter"),
title=title,
description=description,
article=article,
examples=[
[42],
[1337],
[2024]
]
)
# Arayüzü başlat
iface.launch()