Cicikush commited on
Commit
1a09219
·
verified ·
1 Parent(s): 4ecbf25

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.transforms import ToPILImage
5
+ import numpy as np
6
+
7
+ # --- Model Tanımı ve Sabitler ---
8
+ # Kaggle notebook'unun sonundaki PyTorch Generator modelini buraya kopyalıyoruz.
9
+ # Bu, model ağırlıklarını yüklemek için gereklidir.
10
+
11
+ latent_size = 128 # Gürültü vektörünün boyutu (notebook'taki gibi)
12
+ stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) # Normalizasyon için kullanılan değerler
13
+
14
+ # Generator modelinin mimarisi (notebook ile birebir aynı olmalı)
15
+ generator = nn.Sequential(
16
+ # Giriş: (latent_size) x 1 x 1
17
+ nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
18
+ nn.BatchNorm2d(512),
19
+ nn.ReLU(True),
20
+ # Boyut: 512 x 4 x 4
21
+
22
+ nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
23
+ nn.BatchNorm2d(256),
24
+ nn.ReLU(True),
25
+ # Boyut: 256 x 8 x 8
26
+
27
+ nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
28
+ nn.BatchNorm2d(128),
29
+ nn.ReLU(True),
30
+ # Boyut: 128 x 16 x 16
31
+
32
+ nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
33
+ nn.BatchNorm2d(64),
34
+ nn.ReLU(True),
35
+ # Boyut: 64 x 32 x 32
36
+
37
+ nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
38
+ nn.Tanh()
39
+ # Çıkış Boyutu: 3 x 64 x 64
40
+ )
41
+
42
+ # --- Model Yükleme ---
43
+ # Hugging Face Spaces'te genellikle CPU kullanılır, bu yüzden cihaza zorluyoruz.
44
+ device = torch.device("cpu")
45
+
46
+ # Modeli CPU'ya taşı
47
+ generator.to(device)
48
+
49
+ # Eğitilmiş ağırlıkları yükle
50
+ # Notebook'ta 'G.ckpt' olarak kaydedilmişti. 'generator.pth' olarak yeniden adlandırdığınızı varsayıyoruz.
51
+ # Eğer dosya adınız farklıysa, aşağıdaki satırı güncelleyin.
52
+ MODEL_PATH = 'G.ckpt'
53
+ try:
54
+ generator.load_state_dict(torch.load(MODEL_PATH, map_location=device))
55
+ print(f"'{MODEL_PATH}' dosyasından model ağırlıkları başarıyla yüklendi.")
56
+ except FileNotFoundError:
57
+ print(f"UYARI: '{MODEL_PATH}' dosyası bulunamadı. Model rastgele ağırlıklarla çalışacak.")
58
+ # Bu durumda uygulama çalışır ama anlamsız resimler üretir.
59
+ except Exception as e:
60
+ print(f"Model yüklenirken bir hata oluştu: {e}")
61
+
62
+ # Modeli inference (çıkarım) moduna al. Bu, BatchNorm ve Dropout katmanlarını doğru moda geçirir.
63
+ generator.eval()
64
+
65
+ # --- Görüntü Üretme ve Arayüz Fonksiyonları ---
66
+
67
+ def denorm(img_tensors):
68
+ """Görüntü tensor'ünü [-1, 1] aralığından [0, 1] aralığına geri dönüştürür."""
69
+ return img_tensors * stats[1][0] + stats[0][0]
70
+
71
+ def generate_image(seed):
72
+ """
73
+ Verilen bir seed'e göre rastgele gürültüden bir anime karakteri üretir.
74
+ """
75
+ # Tekrarlanabilirlik için seed ayarla
76
+ # Gradio'nun state yönetimi nedeniyle bu her zaman mükemmel çalışmayabilir ama denemeye değer.
77
+ if seed:
78
+ torch.manual_seed(int(seed))
79
+
80
+ # Rastgele gürültü vektörü oluştur
81
+ noise = torch.randn(1, latent_size, 1, 1, device=device)
82
+
83
+ # Gradient hesaplamalarını devre dışı bırakarak performansı artır
84
+ with torch.no_grad():
85
+ # Gürültüyü modele vererek sahte bir resim üret
86
+ fake_image_tensor = generator(noise)
87
+
88
+ # Tensor'ü PIL Image formatına dönüştürmek için denormalize et
89
+ # Modelin çıktısı [-1, 1] (Tanh) aralığında.
90
+ denormalized_tensor = denorm(fake_image_tensor)
91
+
92
+ # Tensor'ü [0, 255] aralığında bir numpy array'e dönüştür ve PIL Image yap
93
+ pil_image = ToPILImage()(denormalized_tensor.squeeze(0)) # batch boyutunu (ilk boyutu) kaldır
94
+
95
+ return pil_image
96
+
97
+ # --- Gradio Arayüz Tanımı ---
98
+ title = "Anime Karakter Üretici (PyTorch GAN)"
99
+ description = """
100
+ Bu demo, bir GAN (Generative Adversarial Network) kullanarak rastgele anime karakter portreleri üretir.
101
+ Model, Kaggle'daki bir anime karakter veri seti ile eğitilmiştir.
102
+ Farklı karakterler görmek için aşağıdaki 'Seed' değerini değiştirin veya 'Rastgele Seed' düğmesine basarak rastgele bir değerle üretin.
103
+ """
104
+ article = "<p style='text-align: center;'><a href='https://www.kaggle.com/code/melissamonfared/anime-character-generation-dsgan-gan' target='_blank'>Eğitim için kullanılan Kaggle Notebook'u</a></p>"
105
+
106
+ iface = gr.Interface(
107
+ fn=generate_image,
108
+ inputs=gr.Number(label="Seed (Rastgelelik için başlangıç değeri)", value=1234),
109
+ outputs=gr.Image(type="pil", label="Üretilen Karakter"),
110
+ title=title,
111
+ description=description,
112
+ article=article,
113
+ examples=[
114
+ [42],
115
+ [1337],
116
+ [2024]
117
+ ]
118
+ )
119
+
120
+ # Arayüzü başlat
121
+ iface.launch()