BeefyDoesAI commited on
Commit
6a46c1d
·
verified ·
1 Parent(s): 1197063

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ # --- 1. DEFINE THE BRAIN ---
7
+ class Generator(nn.Module):
8
+ def __init__(self):
9
+ super(Generator, self).__init__()
10
+ self.net = nn.Sequential(
11
+ nn.Linear(32, 128), nn.BatchNorm1d(128), nn.LeakyReLU(0.2),
12
+ nn.Linear(128, 64), nn.BatchNorm1d(64), nn.LeakyReLU(0.2),
13
+ nn.Linear(64, 1), nn.Sigmoid()
14
+ )
15
+ def forward(self, x): return self.net(x)
16
+
17
+ # --- 2. LOAD MODEL ---
18
+ # We download the weights directly from your model repo
19
+ MODEL_ID = "BeefyDoesAI/Number-E"
20
+ FILENAME = "NumberE.pth"
21
+
22
+ try:
23
+ weights_path = hf_hub_download(repo_id=MODEL_ID, filename=FILENAME)
24
+ # Spaces run on CPU by default (which is fine for this tiny model)
25
+ device = torch.device("cpu")
26
+ model = Generator().to(device)
27
+ model.load_state_dict(torch.load(weights_path, map_location=device))
28
+ model.eval()
29
+ except Exception as e:
30
+ raise RuntimeError(f"Failed to load model: {e}")
31
+
32
+ # --- 3. GENERATE FUNCTION ---
33
+ def generate(count, digits):
34
+ count = int(count)
35
+ digits = int(digits)
36
+
37
+ # Generate Noise
38
+ noise = torch.rand(count, 32).to(device)
39
+
40
+ # Run Model
41
+ with torch.no_grad():
42
+ output = model(noise)
43
+
44
+ # Process output
45
+ multiplier = 10 ** digits
46
+ raw = output.flatten().tolist()
47
+ integers = [str(int(val * multiplier)) for val in raw]
48
+
49
+ return ", ".join(integers)
50
+
51
+ # --- 4. UI ---
52
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
53
+ gr.Markdown(f"# Number-E Demo")
54
+ gr.Markdown("Generating numbers using a custom GAN architecture.")
55
+
56
+ with gr.Row():
57
+ qty = gr.Slider(1, 100, value=10, label="Quantity", step=1)
58
+ dig = gr.Slider(1, 10, value=2, label="Digits", step=1)
59
+ btn = gr.Button("Generate", variant="primary")
60
+
61
+ out = gr.Code(label="Output")
62
+ btn.click(generate, inputs=[qty, dig], outputs=out)
63
+
64
+ demo.launch()