Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision.transforms import ToPILImage
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# --- Model Tanımı ve Sabitler ---
|
| 8 |
+
# Kaggle notebook'unun sonundaki PyTorch Generator modelini buraya kopyalıyoruz.
|
| 9 |
+
# Bu, model ağırlıklarını yüklemek için gereklidir.
|
| 10 |
+
|
| 11 |
+
latent_size = 128 # Gürültü vektörünün boyutu (notebook'taki gibi)
|
| 12 |
+
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) # Normalizasyon için kullanılan değerler
|
| 13 |
+
|
| 14 |
+
# Generator modelinin mimarisi (notebook ile birebir aynı olmalı)
|
| 15 |
+
generator = nn.Sequential(
|
| 16 |
+
# Giriş: (latent_size) x 1 x 1
|
| 17 |
+
nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
|
| 18 |
+
nn.BatchNorm2d(512),
|
| 19 |
+
nn.ReLU(True),
|
| 20 |
+
# Boyut: 512 x 4 x 4
|
| 21 |
+
|
| 22 |
+
nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
|
| 23 |
+
nn.BatchNorm2d(256),
|
| 24 |
+
nn.ReLU(True),
|
| 25 |
+
# Boyut: 256 x 8 x 8
|
| 26 |
+
|
| 27 |
+
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
|
| 28 |
+
nn.BatchNorm2d(128),
|
| 29 |
+
nn.ReLU(True),
|
| 30 |
+
# Boyut: 128 x 16 x 16
|
| 31 |
+
|
| 32 |
+
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
|
| 33 |
+
nn.BatchNorm2d(64),
|
| 34 |
+
nn.ReLU(True),
|
| 35 |
+
# Boyut: 64 x 32 x 32
|
| 36 |
+
|
| 37 |
+
nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
|
| 38 |
+
nn.Tanh()
|
| 39 |
+
# Çıkış Boyutu: 3 x 64 x 64
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# --- Model Yükleme ---
|
| 43 |
+
# Hugging Face Spaces'te genellikle CPU kullanılır, bu yüzden cihaza zorluyoruz.
|
| 44 |
+
device = torch.device("cpu")
|
| 45 |
+
|
| 46 |
+
# Modeli CPU'ya taşı
|
| 47 |
+
generator.to(device)
|
| 48 |
+
|
| 49 |
+
# Eğitilmiş ağırlıkları yükle
|
| 50 |
+
# Notebook'ta 'G.ckpt' olarak kaydedilmişti. 'generator.pth' olarak yeniden adlandırdığınızı varsayıyoruz.
|
| 51 |
+
# Eğer dosya adınız farklıysa, aşağıdaki satırı güncelleyin.
|
| 52 |
+
MODEL_PATH = 'G.ckpt'
|
| 53 |
+
try:
|
| 54 |
+
generator.load_state_dict(torch.load(MODEL_PATH, map_location=device))
|
| 55 |
+
print(f"'{MODEL_PATH}' dosyasından model ağırlıkları başarıyla yüklendi.")
|
| 56 |
+
except FileNotFoundError:
|
| 57 |
+
print(f"UYARI: '{MODEL_PATH}' dosyası bulunamadı. Model rastgele ağırlıklarla çalışacak.")
|
| 58 |
+
# Bu durumda uygulama çalışır ama anlamsız resimler üretir.
|
| 59 |
+
except Exception as e:
|
| 60 |
+
print(f"Model yüklenirken bir hata oluştu: {e}")
|
| 61 |
+
|
| 62 |
+
# Modeli inference (çıkarım) moduna al. Bu, BatchNorm ve Dropout katmanlarını doğru moda geçirir.
|
| 63 |
+
generator.eval()
|
| 64 |
+
|
| 65 |
+
# --- Görüntü Üretme ve Arayüz Fonksiyonları ---
|
| 66 |
+
|
| 67 |
+
def denorm(img_tensors):
|
| 68 |
+
"""Görüntü tensor'ünü [-1, 1] aralığından [0, 1] aralığına geri dönüştürür."""
|
| 69 |
+
return img_tensors * stats[1][0] + stats[0][0]
|
| 70 |
+
|
| 71 |
+
def generate_image(seed):
|
| 72 |
+
"""
|
| 73 |
+
Verilen bir seed'e göre rastgele gürültüden bir anime karakteri üretir.
|
| 74 |
+
"""
|
| 75 |
+
# Tekrarlanabilirlik için seed ayarla
|
| 76 |
+
# Gradio'nun state yönetimi nedeniyle bu her zaman mükemmel çalışmayabilir ama denemeye değer.
|
| 77 |
+
if seed:
|
| 78 |
+
torch.manual_seed(int(seed))
|
| 79 |
+
|
| 80 |
+
# Rastgele gürültü vektörü oluştur
|
| 81 |
+
noise = torch.randn(1, latent_size, 1, 1, device=device)
|
| 82 |
+
|
| 83 |
+
# Gradient hesaplamalarını devre dışı bırakarak performansı artır
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
# Gürültüyü modele vererek sahte bir resim üret
|
| 86 |
+
fake_image_tensor = generator(noise)
|
| 87 |
+
|
| 88 |
+
# Tensor'ü PIL Image formatına dönüştürmek için denormalize et
|
| 89 |
+
# Modelin çıktısı [-1, 1] (Tanh) aralığında.
|
| 90 |
+
denormalized_tensor = denorm(fake_image_tensor)
|
| 91 |
+
|
| 92 |
+
# Tensor'ü [0, 255] aralığında bir numpy array'e dönüştür ve PIL Image yap
|
| 93 |
+
pil_image = ToPILImage()(denormalized_tensor.squeeze(0)) # batch boyutunu (ilk boyutu) kaldır
|
| 94 |
+
|
| 95 |
+
return pil_image
|
| 96 |
+
|
| 97 |
+
# --- Gradio Arayüz Tanımı ---
|
| 98 |
+
title = "Anime Karakter Üretici (PyTorch GAN)"
|
| 99 |
+
description = """
|
| 100 |
+
Bu demo, bir GAN (Generative Adversarial Network) kullanarak rastgele anime karakter portreleri üretir.
|
| 101 |
+
Model, Kaggle'daki bir anime karakter veri seti ile eğitilmiştir.
|
| 102 |
+
Farklı karakterler görmek için aşağıdaki 'Seed' değerini değiştirin veya 'Rastgele Seed' düğmesine basarak rastgele bir değerle üretin.
|
| 103 |
+
"""
|
| 104 |
+
article = "<p style='text-align: center;'><a href='https://www.kaggle.com/code/melissamonfared/anime-character-generation-dsgan-gan' target='_blank'>Eğitim için kullanılan Kaggle Notebook'u</a></p>"
|
| 105 |
+
|
| 106 |
+
iface = gr.Interface(
|
| 107 |
+
fn=generate_image,
|
| 108 |
+
inputs=gr.Number(label="Seed (Rastgelelik için başlangıç değeri)", value=1234),
|
| 109 |
+
outputs=gr.Image(type="pil", label="Üretilen Karakter"),
|
| 110 |
+
title=title,
|
| 111 |
+
description=description,
|
| 112 |
+
article=article,
|
| 113 |
+
examples=[
|
| 114 |
+
[42],
|
| 115 |
+
[1337],
|
| 116 |
+
[2024]
|
| 117 |
+
]
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Arayüzü başlat
|
| 121 |
+
iface.launch()
|