caikybaldo999 commited on
Commit
f2c826a
·
verified ·
1 Parent(s): 10827e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -33
app.py CHANGED
@@ -1,37 +1,34 @@
1
- # Function to generate an image from a text prompt
2
- def generate_image_from_prompt(prompt, generator, device, z_dim):
3
- # A very simple way to map a prompt to a latent vector.
4
- # In a real application, you would use a more sophisticated text encoder.
5
- # Here, we just use a hash of the prompt to seed a random number generator.
6
- import hashlib
7
- seed = int(hashlib.sha256(prompt.encode('utf-8')).hexdigest(), 16) % (2**32 - 1)
8
- torch.manual_seed(seed)
9
 
10
- noise = torch.randn(1, z_dim, 1, 1).to(device)
11
- with torch.no_grad():
12
- fake_img = generator(noise).detach().cpu().squeeze(0)
13
- img = (fake_img.permute(1, 2, 0) + 1) / 2
14
- return img
 
 
15
 
16
- # Load the trained generator model
17
- # Make sure to run the training cell (VDcHP883uCpZ) at least once to save the model
18
- gen_file = os.path.join(model_path, "generator128.pth")
19
- if os.path.exists(gen_file):
20
- # Initialize the generator model
21
- G = Generator().to(device)
22
- G.load_state_dict(torch.load(gen_file))
23
- G.eval() # Set the generator to evaluation mode
24
- print("✅ Generator model loaded successfully!")
25
- else:
26
- print("⚠️ Generator model not found. Please run the training cell first.")
27
 
28
- # Example usage:
29
- if os.path.exists(gen_file):
30
- prompt = "cat and dog"
31
- generated_image = generate_image_from_prompt(prompt, G, device, z_dim)
 
 
 
 
32
 
33
- plt.figure(figsize=(4, 4))
34
- plt.axis("off")
35
- plt.title(f"Generated Image for Prompt: '{prompt}'")
36
- plt.imshow(generated_image)
37
- plt.show()
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import gradio as gr
5
+ from model import Generator
 
 
 
6
 
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Configurações
10
+ latent_dim = 100
11
+ generator = Generator(latent_dim).to(device)
12
+ generator.load_state_dict(torch.load("generator128.pth", map_location=device))
13
+ generator.eval()
14
 
15
+ # Função para gerar imagem
16
+ def generate_image(seed: int):
17
+ torch.manual_seed(seed)
18
+ noise = torch.randn(1, latent_dim, 1, 1, device=device)
19
+ with torch.no_grad():
20
+ fake_image = generator(noise).detach().cpu()
21
+ fake_image = (fake_image + 1) / 2 # Traz para [0, 1]
22
+ to_pil = transforms.ToPILImage()
23
+ return to_pil(fake_image.squeeze(0))
 
 
24
 
25
+ # Interface Gradio
26
+ demo = gr.Interface(
27
+ fn=generate_image,
28
+ inputs=gr.Slider(0, 99999, value=42, label="Seed"),
29
+ outputs=gr.Image(type="pil", label="Generated Image"),
30
+ title="ZYI 0.1 - CringeCoin Generator",
31
+ description="Enter a seed to generate a unique 128x128 image with ZYI 0.1."
32
+ )
33
 
34
+ demo.launch()