Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
# --- 1. DEFINE THE BRAIN ---
class Generator(nn.Module):
    """Small fully connected generator: 32-dim noise -> one sigmoid value.

    Input is expected as a 2-D batch of shape (batch, 32). The layer order
    inside ``self.net`` must stay exactly as built here, because
    ``nn.Sequential`` names its children by position (``net.0``, ``net.1``,
    ...) and the pretrained ``state_dict`` is keyed by those names.
    """

    def __init__(self):
        super().__init__()
        stack = []
        # Two hidden stages: Linear -> BatchNorm -> LeakyReLU(0.2).
        for fan_in, fan_out in ((32, 128), (128, 64)):
            stack += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.LeakyReLU(0.2),
            ]
        # Output head squashes to a single value per sample via Sigmoid.
        stack += [nn.Linear(64, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run a batch of noise vectors through the network."""
        return self.net(x)
# --- 2. LOAD MODEL ---
# Download the pretrained weights from the Hugging Face model repo.
MODEL_ID = "BeefyDoesAI/Number-E"
FILENAME = "NumberE.pth"

# Spaces run on CPU by default (which is fine for this tiny model).
device = torch.device("cpu")

try:
    weights_path = hf_hub_download(repo_id=MODEL_ID, filename=FILENAME)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model = Generator().to(device)
    model.load_state_dict(state_dict)
except Exception as e:
    # Chain the original error so the full traceback survives in the logs.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Switch to inference mode: BatchNorm uses its stored running statistics.
model.eval()
# --- 3. GENERATE FUNCTION ---
def generate(count, digits):
    """Sample ``count`` numbers from the generator and return them as text.

    Args:
        count: How many numbers to generate; converted with ``int()`` because
            the UI slider value may arrive as a float.
        digits: Each model output ``v`` becomes ``int(v * 10**digits)``,
            clamped to at most ``10**digits - 1`` so the result never has
            more than ``digits`` decimal digits. Negative values are treated
            as 0.

    Returns:
        The generated integers joined with ", ", or "" when ``count`` is not
        positive.
    """
    count = int(count)
    digits = max(int(digits), 0)
    if count <= 0:
        return ""

    # Noise width must match the input size of Generator's first nn.Linear (32).
    noise = torch.rand(count, 32, device=device)

    with torch.no_grad():
        output = model(noise)

    multiplier = 10 ** digits
    # A float32 sigmoid can round to exactly 1.0, which would produce
    # 10**digits (one digit too many); clamp to the largest valid value.
    largest = multiplier - 1
    values = output.flatten().tolist()
    return ", ".join(str(min(int(val * multiplier), largest)) for val in values)
# --- 4. UI ---
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    # Plain string: the original f-string had no placeholders.
    gr.Markdown("# Number-E Demo")
    gr.Markdown("Generating numbers using a custom GAN architecture.")
    with gr.Row():
        qty = gr.Slider(1, 100, value=10, label="Quantity", step=1)
        dig = gr.Slider(1, 10, value=2, label="Digits", step=1)
    btn = gr.Button("Generate", variant="primary")
    out = gr.Code(label="Output")
    # Slider values are passed positionally as generate(count, digits).
    btn.click(generate, inputs=[qty, dig], outputs=out)

# Only start the server when this file is executed as a script; importing the
# module just builds `demo`.
# NOTE(review): assumes the host runs this file directly (e.g. `python app.py`)
# — confirm against the Space's launch command.
if __name__ == "__main__":
    demo.launch()