NguyenLe2004 commited on
Commit
62abfd9
·
1 Parent(s): 97761ce

Add application file

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gradio as gr
4
+ import numpy as np
5
+ import urllib.request
6
+ from PIL import Image
7
+
8
+ # Tải mô hình GAN đã lưu
9
+ model_url = "https://github.com/NguyenLe2004/cat-image-generate/raw/refs/heads/main/generate/cat_face_generate_model.pt"
10
+ model_path = "cat_image_generate.pt"
11
+
12
+ # Tải mô hình nếu chưa tồn tại
13
+ if not os.path.exists(model_path):
14
+ urllib.request.urlretrieve(model_url, model_path)
15
+
16
+ # Load mô hình
17
+ model = torch.jit.load(model_path, map_location=torch.device('cpu'))
18
+ model.eval()
19
+
20
+ def generate_image():
21
+ # Tạo một vector ngẫu nhiên
22
+ latent_dim = 100
23
+ z = torch.randn(1, latent_dim, 1 , 1)
24
+
25
+ with torch.no_grad():
26
+ generated_image = model(z)
27
+ generated_image = generated_image.squeeze().numpy()
28
+ generated_image = generated_image*0.5 + 0.5
29
+ generated_image = np.transpose(generated_image, (1, 2, 0))
30
+
31
+ # Chuyển về kiểu uint8 để hiển thị
32
+ generated_image = (generated_image * 255).clip(0, 255).astype(np.uint8)
33
+
34
+ # Dùng PIL để chuyển thành ảnh
35
+ image = Image.fromarray(generated_image)
36
+ return image
37
+
38
+ # Giao diện Gradio
39
+ demo = gr.Interface(fn=generate_image, inputs=None, outputs="image")
40
+ demo.launch()