nitesh501 commited on
Commit
d85ce22
·
verified ·
1 Parent(s): 02259e1

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -143
app.py DELETED
@@ -1,143 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from tinydit import TinyDit
5
- from vae import Vae
6
- from sampler import ddim_sample, num_timesteps
7
- import gradio as gr
8
- import numpy as np
9
- from PIL import Image
10
- import os
11
- import subprocess
12
-
13
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
-
15
- vae = Vae(latent_channels=16).to(device)
16
- tiny_dit = TinyDit(latent_channels=16, patch_size=2, dim=768, depth=12, num_classes=1).to(device)
17
-
18
- ckpt_url = "https://huggingface.co/nitesh501/tinydit/resolve/main/tinydit_.pth"
19
-
20
- if not os.path.exists("tinydit_.pth"):
21
- subprocess.run(["wget", ckpt_url])
22
-
23
- ckpt = torch.load("tinydit_.pth", map_location=torch.device(device))
24
-
25
- tiny_dit.load_state_dict(ckpt['model'])
26
- vae.load_state_dict(ckpt['vae'])
27
-
28
-
29
-
30
- @torch.inference_mode()
31
- def generate_image(steps=50, seed=-1):
32
- if seed == -1:
33
- seed = torch.randint(0, 2**32, (1,)).item()
34
-
35
- torch.manual_seed(int(seed))
36
-
37
- x = torch.randn(1, 16, 8, 8).to(device)
38
- x = x.clamp(-1, 1)
39
-
40
- timesteps = torch.linspace(
41
- num_timesteps - 1,
42
- 0,
43
- steps,
44
- device=device
45
- ).long()
46
-
47
- for i in range(len(timesteps) - 1):
48
-
49
- t = timesteps[i]
50
- t_prev = timesteps[i + 1]
51
-
52
- t_tensor = torch.full(
53
- (1,),
54
- t,
55
- device=device,
56
- dtype=torch.long
57
- )
58
-
59
- t_prev_tensor = torch.full(
60
- (1,),
61
- t_prev,
62
- device=device,
63
- dtype=torch.long
64
- )
65
-
66
- label = torch.tensor([0], device=device)
67
-
68
- x = ddim_sample(
69
- tiny_dit,
70
- x,
71
- t_tensor,
72
- t_prev_tensor,
73
- label
74
- )
75
-
76
- img = vae.decoder(x / 0.18215)
77
-
78
- img = (
79
- img.squeeze(0)
80
- .detach()
81
- .cpu()
82
- .numpy()
83
- .transpose(1, 2, 0)
84
- )
85
-
86
- img = (img / 2 + 0.5).clip(0, 1)
87
- img = (img * 255).astype(np.uint8)
88
-
89
- pil_img = Image.fromarray(img)
90
-
91
- return pil_img
92
-
93
-
94
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
95
-
96
- gr.Markdown(
97
- """
98
- # 🎨 TinyDiT AI Anime Face Generator
99
- Generate AI anime faces using TinyDiT model
100
- """
101
- )
102
-
103
- with gr.Row():
104
-
105
- with gr.Column(scale=1):
106
-
107
- steps = gr.Slider(
108
- minimum=0,
109
- maximum=100,
110
- value=50,
111
- step=1,
112
- label="Sampling Steps"
113
- )
114
-
115
- seed = gr.Number(
116
- value=-1,
117
- label="Seed",
118
- precision=0
119
- )
120
-
121
- generate_btn = gr.Button(
122
- "🚀 Generate Image",
123
- variant="primary",
124
- size="lg"
125
- )
126
-
127
- with gr.Column(scale=2):
128
-
129
- output_image = gr.Image(
130
- label="",
131
- type="pil",
132
- height=256,
133
- width=256,
134
- interactive=False
135
- )
136
-
137
- generate_btn.click(
138
- fn=generate_image,
139
- inputs=[steps, seed],
140
- outputs=output_image
141
- )
142
-
143
- demo.launch()