|
|
"""dcode - Text to Polargraph Gcode via Stable Diffusion""" |
|
|
|
|
|
import re |
|
|
import os |
|
|
import json |
|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from pathlib import Path |
|
|
import spaces |
|
|
|
|
|
|
|
|
# Drawable area in machine coordinates (mm), centred on the origin:
# 841 mm wide x 1189 mm tall, matching the machine size shown in the UI footer.
BOUNDS = {"left": -420.5, "right": 420.5, "top": 594.5, "bottom": -594.5}

# Cache for the loaded pipeline / decoder / tokenizer; populated lazily by get_model().
_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GcodeDecoderConfigV3:
    """Hyperparameters for the v3 gcode decoder (CNN latent projector + transformer).

    Plain attribute container; every constructor argument is stored verbatim
    under the same name.
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 1024,
        num_layers: int = 12,
        num_heads: int = 16,
        vocab_size: int = 8192,
        max_seq_len: int = 2048,
        dropout: float = 0.1,
        ffn_mult: int = 4,
    ):
        # Shape of the incoming diffusion latent.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        # Transformer width / depth / attention heads.
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        # Token vocabulary and maximum sequence length (position-embedding table size).
        self.vocab_size, self.max_seq_len = vocab_size, max_seq_len
        # Dropout rate and feed-forward width multiplier.
        self.dropout, self.ffn_mult = dropout, ffn_mult
|
|
|
|
|
|
|
|
class CNNLatentProjector(nn.Module):
    """CNN-based latent projector preserving spatial structure.

    Four 3x3 / stride-2 / padding-1 convolutions (widths 64, 128, 256,
    ``hidden_size``), each followed by LayerNorm and GELU, shrink a
    ``(B, latent_channels, latent_size, latent_size)`` latent to a
    ``(B, hidden_size, s, s)`` grid. Each grid cell becomes one memory token,
    giving ``(B, s * s, hidden_size)`` plus a learned positional offset.
    For the default 64x64 latent, s = 4 and there are 16 memory tokens.
    """

    def __init__(self, config: "GcodeDecoderConfigV3"):
        super().__init__()

        # Side length after each stride-2 conv with k=3, p=1: ceil(n / 2).
        # Derived from the config instead of hard-coding 32/16/8/4 so other
        # latent sizes also work. Defaults give identical layers and state-dict keys.
        sides = []
        side = config.latent_size
        for _ in range(4):
            side = (side + 1) // 2
            sides.append(side)

        widths = [64, 128, 256, config.hidden_size]
        layers = []
        in_channels = config.latent_channels
        for out_channels, s in zip(widths, sides):
            # Same (Conv2d, LayerNorm, GELU) triplets and ordering as the
            # original explicit Sequential, so indices cnn.0 ... cnn.11 match checkpoints.
            layers += [
                nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
                nn.LayerNorm([out_channels, s, s]),
                nn.GELU(),
            ]
            in_channels = out_channels
        self.cnn = nn.Sequential(*layers)

        self.num_memory_tokens = sides[-1] * sides[-1]
        self.memory_pos = nn.Parameter(torch.randn(1, self.num_memory_tokens, config.hidden_size) * 0.02)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        """Project ``latent`` to memory tokens of shape (B, num_memory_tokens, hidden_size)."""
        # (B, hidden, s, s) -> (B, hidden, s*s) -> (B, s*s, hidden)
        x = self.cnn(latent).flatten(2).transpose(1, 2)
        # memory_pos has batch dim 1 and broadcasts over B.
        return x + self.memory_pos
|
|
|
|
|
|
|
|
class GcodeDecoderV3(nn.Module):
    """Large transformer decoder for gcode generation (v3).

    Token embeddings plus learned position embeddings pass through
    ``num_layers`` pre-norm ``nn.TransformerDecoderLayer``s. These cross-attend
    to memory tokens that ``CNNLatentProjector`` builds from the diffusion latent.
    """

    def __init__(self, config: "GcodeDecoderConfigV3"):
        super().__init__()
        self.config = config

        self.latent_proj = CNNLatentProjector(config)
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size)
        self.embed_drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_heads,
                dim_feedforward=config.hidden_size * config.ffn_mult,
                dropout=config.dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            )
            for _ in range(config.num_layers)
        ])

        self.ln_f = nn.LayerNorm(config.hidden_size)
        # Output projection; not weight-tied to token_embed.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits of shape (B, seq_len, vocab_size).

        Args:
            latent: diffusion latent fed to the CNN projector. Its dtype is
                also used for the additive causal mask.
            input_ids: (B, seq_len) tensor of token ids.

        Raises:
            ValueError: if seq_len exceeds ``config.max_seq_len`` (the size of
                the position-embedding table).
        """
        seq_len = input_ids.shape[1]
        if seq_len > self.config.max_seq_len:
            # Fail clearly here instead of via an embedding index error
            # (or a CUDA device-side assert).
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.config.max_seq_len}"
            )
        device = input_ids.device

        memory = self.latent_proj(latent)
        positions = torch.arange(seq_len, device=device)
        x = self.embed_drop(self.token_embed(input_ids) + self.pos_embed(positions))

        # Additive mask (-inf above the diagonal): each position attends only to itself and earlier tokens.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=latent.dtype)
        for layer in self.layers:
            x = layer(x, memory, tgt_mask=causal_mask)

        return self.lm_head(self.ln_f(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GcodeDecoderConfigV2:
    """Hyperparameters for the v2 gcode decoder (MLP latent projector + transformer).

    Stores each constructor argument under the same name. It also derives
    ``latent_dim``, the length of the flattened latent
    (channels * size * size).
    """

    def __init__(
        self,
        latent_channels: int = 4,
        latent_size: int = 64,
        hidden_size: int = 768,
        num_layers: int = 6,
        num_heads: int = 12,
        vocab_size: int = 32128,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
    ):
        # Incoming latent shape and its flattened length.
        self.latent_channels, self.latent_size = latent_channels, latent_size
        self.latent_dim = latent_channels * latent_size * latent_size
        # Transformer width / depth / attention heads.
        self.hidden_size, self.num_layers, self.num_heads = hidden_size, num_layers, num_heads
        # Vocabulary, maximum sequence length and dropout rate.
        self.vocab_size, self.max_seq_len, self.dropout = vocab_size, max_seq_len, dropout
|
|
|
|
|
|
|
|
class GcodeDecoderV2(nn.Module): |
|
|
def __init__(self, config: GcodeDecoderConfigV2): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
|
|
|
self.latent_proj = nn.Sequential( |
|
|
nn.Linear(config.latent_dim, config.hidden_size * 4), |
|
|
nn.GELU(), |
|
|
nn.Linear(config.hidden_size * 4, config.hidden_size * 16), |
|
|
nn.LayerNorm(config.hidden_size * 16), |
|
|
) |
|
|
|
|
|
self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size) |
|
|
self.pos_embed = nn.Embedding(config.max_seq_len, config.hidden_size) |
|
|
|
|
|
self.layers = nn.ModuleList([ |
|
|
nn.TransformerDecoderLayer( |
|
|
d_model=config.hidden_size, |
|
|
nhead=config.num_heads, |
|
|
dim_feedforward=config.hidden_size * 4, |
|
|
dropout=config.dropout, |
|
|
activation='gelu', |
|
|
batch_first=True, |
|
|
norm_first=True, |
|
|
) |
|
|
for _ in range(config.num_layers) |
|
|
]) |
|
|
|
|
|
self.ln_f = nn.LayerNorm(config.hidden_size) |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
self.lm_head.weight = self.token_embed.weight |
|
|
|
|
|
def forward(self, latent: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: |
|
|
batch_size, seq_len = input_ids.shape |
|
|
device = input_ids.device |
|
|
dtype = latent.dtype |
|
|
|
|
|
latent_flat = latent.view(batch_size, -1) |
|
|
memory = self.latent_proj(latent_flat) |
|
|
memory = memory.view(batch_size, 16, self.config.hidden_size) |
|
|
|
|
|
positions = torch.arange(seq_len, device=device) |
|
|
x = self.token_embed(input_ids) + self.pos_embed(positions) |
|
|
|
|
|
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=device, dtype=dtype) |
|
|
|
|
|
for layer in self.layers: |
|
|
x = layer(x, memory, tgt_mask=causal_mask) |
|
|
|
|
|
x = self.ln_f(x) |
|
|
return self.lm_head(x) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return the entries of *state_dict* whose key starts with *prefix*, with only that leading prefix removed."""
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def _build_gcode_decoder(gcode_cfg: dict, is_v3: bool) -> nn.Module:
    """Instantiate the v3 or v2 gcode decoder from the checkpoint's ``gcode_decoder`` config section."""
    if is_v3:
        decoder_config = GcodeDecoderConfigV3(
            latent_channels=gcode_cfg.get("latent_channels", 4),
            latent_size=gcode_cfg.get("latent_size", 64),
            hidden_size=gcode_cfg.get("hidden_size", 1024),
            num_layers=gcode_cfg.get("num_layers", 12),
            num_heads=gcode_cfg.get("num_heads", 16),
            vocab_size=gcode_cfg.get("vocab_size", 8192),
            max_seq_len=gcode_cfg.get("max_seq_len", 2048),
            ffn_mult=gcode_cfg.get("ffn_mult", 4),
        )
        return GcodeDecoderV3(decoder_config)

    decoder_config = GcodeDecoderConfigV2(
        latent_channels=gcode_cfg.get("latent_channels", 4),
        latent_size=gcode_cfg.get("latent_size", 64),
        hidden_size=gcode_cfg.get("hidden_size", 768),
        num_layers=gcode_cfg.get("num_layers", 6),
        num_heads=gcode_cfg.get("num_heads", 12),
        vocab_size=gcode_cfg.get("vocab_size", 32128),
        max_seq_len=gcode_cfg.get("max_seq_len", 1024),
    )
    return GcodeDecoderV2(decoder_config)


def get_model():
    """Load and cache the SD-Gcode model.

    On the first call this:

    * downloads ``config.json`` and ``pytorch_model.bin`` from the
      ``twarner/dcode-sd-gcode`` Hub repo;
    * builds the Stable Diffusion pipeline named by ``sd_model_id``;
    * loads the finetuned text-encoder, UNet and gcode-decoder weights;
    * picks a gcode tokenizer, falling back to the flan-t5-base tokenizer.

    Later calls return the cached result.

    Returns:
        dict with keys "pipe", "gcode_decoder", "gcode_tokenizer", "device",
        "dtype", "num_inference_steps" and "is_v3".
    """
    global _model
    if _model is not None:
        return _model

    # Heavy imports deferred until the model is actually needed.
    from diffusers import StableDiffusionPipeline
    from transformers import AutoTokenizer, PreTrainedTokenizerFast
    from huggingface_hub import hf_hub_download

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    print("Loading SD-Gcode model...")

    config_path = hf_hub_download("twarner/dcode-sd-gcode", "config.json")
    weights_path = hf_hub_download("twarner/dcode-sd-gcode", "pytorch_model.bin")

    with open(config_path) as f:
        config = json.load(f)

    # Heuristic: v3 checkpoints declare ffn_mult or use a hidden size >= 1024.
    gcode_cfg = config.get("gcode_decoder", {})
    is_v3 = gcode_cfg.get("ffn_mult") is not None or gcode_cfg.get("hidden_size", 768) >= 1024
    print(f"Model version: {'v3' if is_v3 else 'v2'}")

    sd_model_id = config.get("sd_model_id", "runwayml/stable-diffusion-v1-5")
    print(f"Loading SD from {sd_model_id}...")
    pipe = StableDiffusionPipeline.from_pretrained(
        sd_model_id,
        torch_dtype=dtype,
        safety_checker=None,
    ).to(device)

    gcode_decoder = _build_gcode_decoder(gcode_cfg, is_v3).to(device, dtype)

    print("Loading finetuned weights...")
    # NOTE(security): weights_only=False unpickles arbitrary objects from the
    # downloaded checkpoint. This is only acceptable because the Hub repo is
    # trusted. If the file is a plain tensor state dict, weights_only=True is safer.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)

    # The checkpoint bundles all three finetuned components under key prefixes.
    text_encoder_state = _strip_prefix(state_dict, "text_encoder.")
    if text_encoder_state:
        result = pipe.text_encoder.load_state_dict(text_encoder_state, strict=False)
        print(f"Loaded {len(text_encoder_state)} text encoder weights "
              f"({len(result.unexpected_keys)} unexpected keys ignored)")

    unet_state = _strip_prefix(state_dict, "unet.")
    if unet_state:
        result = pipe.unet.load_state_dict(unet_state, strict=False)
        print(f"Loaded {len(unet_state)} UNet weights "
              f"({len(result.unexpected_keys)} unexpected keys ignored)")

    decoder_state = _strip_prefix(state_dict, "gcode_decoder.")
    if decoder_state:
        try:
            gcode_decoder.load_state_dict(decoder_state, strict=True)
            print(f"Loaded {len(decoder_state)} decoder weights (strict)")
        except RuntimeError as e:
            # load_state_dict raises RuntimeError on missing/unexpected/mismatched keys.
            print(f"Strict load failed: {e}")
            gcode_decoder.load_state_dict(decoder_state, strict=False)
            print(f"Loaded {len(decoder_state)} decoder weights (non-strict)")
    else:
        print("WARNING: no gcode_decoder weights found in checkpoint; decoder is randomly initialised")

    gcode_decoder.eval()

    try:
        tokenizer_path = hf_hub_download("twarner/dcode-sd-gcode", "gcode_tokenizer/tokenizer.json")
        gcode_tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        print("Loaded custom gcode tokenizer")
    except Exception as e:
        # Best effort: any download/parse failure falls back to the T5 tokenizer.
        print(f"Custom tokenizer unavailable ({e})")
        gcode_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        print("Using fallback T5 tokenizer")

    _model = {
        "pipe": pipe,
        "gcode_decoder": gcode_decoder,
        "gcode_tokenizer": gcode_tokenizer,
        "device": device,
        "dtype": dtype,
        "num_inference_steps": config.get("num_inference_steps", 20),
        "is_v3": is_v3,
    }
    print("Model loaded!")
    return _model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate_gcode(gcode: str, bounds=None) -> str:
    """Clamp every X/Y coordinate in *gcode* to the machine bounds.

    Each coordinate token is clamped independently. A line carrying several
    commands (common when decoder output lacks line breaks) therefore keeps
    its distinct coordinates instead of having every X/Y overwritten with the
    first one. Matching is case-insensitive. Clamped tokens are rewritten as
    ``X<value>`` / ``Y<value>`` with two decimals. Tokens that do not parse as
    floats (e.g. ``X-``) are left untouched.

    Args:
        gcode: newline-separated gcode text; line structure is preserved.
        bounds: optional dict with "left", "right", "top" and "bottom" limits.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        The gcode with all coordinates clamped.
    """
    b = BOUNDS if bounds is None else bounds
    limits = {"X": (b["left"], b["right"]), "Y": (b["bottom"], b["top"])}

    def _clamp(match):
        axis = match.group(1).upper()
        try:
            value = float(match.group(2))
        except ValueError:
            return match.group(0)
        low, high = limits[axis]
        return f"{axis}{max(low, min(high, value)):.2f}"

    # The character class excludes "\n", so substitution never crosses lines.
    return re.sub(r"([XY])([-\d.]+)", _clamp, gcode, flags=re.IGNORECASE)
|
|
|
|
|
|
|
|
def _split_gcode_commands(gcode: str) -> list:
    """Split raw decoder output into individual non-comment command strings."""
    commands = []
    # v3 output may still contain literal "<newline>" tokens.
    gcode = gcode.replace("<newline>", "\n")
    # Put each ";" on its own line. The G/M split below still recovers commands
    # emitted on the same line (decoder output may have no line breaks).
    for line in gcode.replace(";", "\n;").split("\n"):
        line = line.strip()
        if not line:
            continue
        for part in re.split(r'(?=[GM]\d)', line):
            part = part.strip()
            if part and not part.startswith(";"):
                commands.append(part)
    return commands


def _trace_pen_paths(commands) -> list:
    """Replay commands and return the polylines drawn while the pen was down."""
    paths = []
    current_path = []
    x, y = 0.0, 0.0
    pen_down = False

    for cmd in commands:
        if "M280" in cmd.upper():
            match = re.search(r"S(\d+)", cmd, re.IGNORECASE)
            if match:
                was_down = pen_down
                # Servo angles below 50 mean the pen is on the paper.
                pen_down = int(match.group(1)) < 50
                if was_down and not pen_down:
                    if len(current_path) > 1:
                        paths.append(current_path)
                    # Always reset on pen-up so no stale point leaks into the next stroke.
                    current_path = []
                elif pen_down and not was_down:
                    # A stroke starts where the pen touches down.
                    current_path = [(x, y)]

        x_match = re.search(r"X([-\d.]+)", cmd, re.IGNORECASE)
        y_match = re.search(r"Y([-\d.]+)", cmd, re.IGNORECASE)
        if x_match:
            try:
                x = float(x_match.group(1))
            except ValueError:
                pass
        if y_match:
            try:
                y = float(y_match.group(1))
            except ValueError:
                pass

        if (x_match or y_match) and pen_down:
            current_path.append((x, y))

    if len(current_path) > 1:
        paths.append(current_path)
    return paths


def gcode_to_svg(gcode: str, bounds=None) -> str:
    """Convert gcode to an SVG preview of the pen-down strokes.

    Pen state follows ``M280 ... S<angle>`` (angle < 50 means down). Every
    X/Y move while the pen is down extends the current stroke, which starts
    at the point where the pen was lowered. The Y axis is flipped for SVG.

    Args:
        gcode: gcode text; may contain literal ``<newline>`` tokens.
        bounds: optional dict with "left", "right", "top" and "bottom" limits.
            Defaults to the module-level ``BOUNDS``.

    Returns:
        An ``<svg>`` string showing the work area, the strokes, and a
        "N paths / M points" label.
    """
    b = BOUNDS if bounds is None else bounds
    paths = _trace_pen_paths(_split_gcode_commands(gcode))

    w = b["right"] - b["left"]
    h = b["top"] - b["bottom"]
    padding = 20

    svg = f'''<svg xmlns="http://www.w3.org/2000/svg"
     viewBox="{b["left"] - padding} {-b["top"] - padding} {w + 2*padding} {h + 2*padding}"
     style="width: 100%; height: 480px; border: 1px solid var(--border, #e0e0e0); border-radius: 4px;">
  <style>
    @media (prefers-color-scheme: dark) {{
      .bg {{ fill: #2a2b30; }}
      .work {{ fill: #212226; stroke: #3a3b40; }}
      .stroke {{ stroke: #e8e8e8; }}
      .label {{ fill: #666; }}
    }}
    @media (prefers-color-scheme: light) {{
      .bg {{ fill: #fff; }}
      .work {{ fill: #fafafa; stroke: #ccc; }}
      .stroke {{ stroke: #1a1a1a; }}
      .label {{ fill: #999; }}
    }}
  </style>
  <rect class="bg" x="{b["left"] - padding}" y="{-b["top"] - padding}" width="{w + 2*padding}" height="{h + 2*padding}"/>
  <rect class="work" x="{b["left"]}" y="{-b["top"]}" width="{w}" height="{h}" stroke-width="1"/>
'''

    strokes = []
    for path in paths:
        # SVG y grows downwards, machine Y grows upwards: negate y.
        d = " ".join(f"{'M' if i == 0 else 'L'}{px:.1f},{-py:.1f}" for i, (px, py) in enumerate(path))
        strokes.append(
            f'<path class="stroke" d="{d}" fill="none" stroke-width="1" stroke-linecap="round" stroke-linejoin="round"/>'
        )
    svg += "".join(strokes)

    total_points = sum(len(p) for p in paths)
    svg += f'''
  <text class="label" x="{b["left"] + 8}" y="{-b["top"] + 20}" font-family="monospace" font-size="12">
    {len(paths)} paths / {total_points} points
  </text>
'''
    svg += "</svg>"
    return svg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Return *logits* with every token outside the top-p nucleus set to -inf."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_remove = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is kept; always keep the top token.
    sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
    sorted_remove[:, 0] = False
    # Map the mask from sorted order back to vocabulary order.
    remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float('-inf'))


@spaces.GPU
def generate(prompt: str, temperature: float, max_tokens: int, num_steps: int, guidance: float):
    """Generate gcode from a text prompt.

    1. Run the Stable Diffusion pipeline to a latent.
    2. Sample gcode tokens autoregressively from the decoder (temperature
       scaling + top-p 0.9), stopping at EOS or after ``max_tokens``.
    3. Clamp coordinates to the machine bounds and render a preview.

    Returns:
        (gcode text with a ";" comment header, SVG preview). On error, returns
        an "; Error: ..." line and an empty preview.
    """
    if not prompt or not prompt.strip():
        return "Enter a prompt to generate gcode", gcode_to_svg("")

    try:
        m = get_model()
        pipe = m["pipe"]
        gcode_decoder = m["gcode_decoder"]
        gcode_tokenizer = m["gcode_tokenizer"]
        device = m["device"]
        dtype = m["dtype"]
        is_v3 = m.get("is_v3", False)

        with torch.no_grad():
            result = pipe(
                prompt,
                # Sliders may deliver floats; diffusers expects an int step count.
                num_inference_steps=int(num_steps),
                guidance_scale=guidance,
                output_type="latent",
            )
            latent = result.images.to(dtype)
            print(f"Latent shape: {latent.shape}, dtype: {latent.dtype}")

            # v3 tokenizers start from BOS; v2 (T5-style) decoding starts from ";".
            if is_v3:
                start_id = gcode_tokenizer.bos_token_id or 0
            else:
                start_tokens = gcode_tokenizer.encode(";", add_special_tokens=False)
                start_id = start_tokens[0] if start_tokens else gcode_tokenizer.pad_token_id

            input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
            # Keep the sequence within the position-embedding table.
            max_gen = min(int(max_tokens), gcode_decoder.config.max_seq_len - 1)

            for _ in range(max_gen):
                logits = gcode_decoder(latent, input_ids)
                next_logits = _top_p_filter(logits[:, -1, :] / temperature)
                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                if next_token.item() == gcode_tokenizer.eos_token_id:
                    break

        print(f"Generated {input_ids.shape[1]} tokens")
        gcode = gcode_tokenizer.decode(input_ids[0], skip_special_tokens=True)
        if is_v3:
            gcode = gcode.replace("<newline>", "\n")
        print(f"Decoded gcode length: {len(gcode)} chars")

        gcode = validate_gcode(gcode)
        line_count = len([l for l in gcode.split("\n") if l.strip()])
        svg = gcode_to_svg(gcode)

        # Collapse whitespace so a multi-line prompt cannot escape its ";" comment
        # line and inject raw text into the gcode sent to the machine.
        safe_prompt = " ".join(prompt.split())
        header = f"; dcode output\n; prompt: {safe_prompt}\n; {line_count} commands\n\n"
        return header + gcode, svg

    except Exception as e:
        # UI boundary: log the traceback and report the error in the output box.
        import traceback
        traceback.print_exc()
        return f"; Error: {e}", gcode_to_svg("")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet for the Gradio UI: IBM Plex Mono for all text, a light palette
# defined as CSS custom properties on :root (overridden when
# prefers-color-scheme is dark), themed buttons/panels/inputs, and the
# default Gradio footer hidden.
css = """
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@400;500&display=swap');

:root {
    --bg: #ffffff;
    --bg-secondary: #fafafa;
    --text: #1a1a1a;
    --text-secondary: #666;
    --border: #e0e0e0;
    --btn-bg: #f0f0f0;
    --btn-hover: #e0e0e0;
}

@media (prefers-color-scheme: dark) {
    :root {
        --bg: #212226;
        --bg-secondary: #2a2b30;
        --text: #e8e8e8;
        --text-secondary: #999;
        --border: #3a3b40;
        --btn-bg: #3a3b40;
        --btn-hover: #4a4b50;
    }
}

* {
    font-family: 'IBM Plex Mono', monospace !important;
}

body, .gradio-container {
    background: var(--bg) !important;
    color: var(--text) !important;
}

.gradio-container {
    max-width: 900px !important;
    margin: auto;
}

.gr-button {
    background: var(--btn-bg) !important;
    border: 1px solid var(--border) !important;
    color: var(--text) !important;
    font-weight: 500 !important;
}

.gr-button:hover {
    background: var(--btn-hover) !important;
}

.gr-examples {
    margin-top: 8px !important;
}

footer {
    display: none !important;
}

h1, h2, h3, p, span, label {
    color: var(--text) !important;
}

.gr-box, .gr-panel, .gr-form {
    background: var(--bg-secondary) !important;
    border: 1px solid var(--border) !important;
    border-radius: 4px !important;
}

input, textarea {
    background: var(--bg) !important;
    color: var(--text) !important;
    border: 1px solid var(--border) !important;
    border-radius: 4px !important;
}

.gr-accordion {
    background: var(--bg-secondary) !important;
    border: 1px solid var(--border) !important;
}

a {
    color: var(--text-secondary) !important;
}

a:hover {
    color: var(--text) !important;
}
"""
|
|
|
|
|
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    gr.Markdown("# dcode")
    gr.Markdown("text → polargraph gcode via stable diffusion")

    with gr.Row():
        # Left column: prompt entry, sampling settings, generate button, examples.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="prompt",
                placeholder="describe what to draw...",
                lines=2,
                show_label=True,
            )

            with gr.Accordion("settings", open=False):
                temperature = gr.Slider(0.5, 1.5, value=0.8, label="temperature", step=0.1)
                max_tokens = gr.Slider(256, 2048, value=1024, step=256, label="max tokens")
                num_steps = gr.Slider(10, 50, value=20, step=5, label="diffusion steps")
                guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="guidance")

            generate_btn = gr.Button("generate", variant="secondary")

            gr.Examples(
                examples=[
                    ["a line drawing of a horse"],
                    ["portrait sketch"],
                    ["geometric shapes"],
                ],
                inputs=prompt,
                label=None,
                examples_per_page=3,
            )

        # Right column: SVG preview, with the raw gcode in a collapsed accordion.
        with gr.Column(scale=2):
            preview = gr.HTML(value=gcode_to_svg(""))
            with gr.Accordion("gcode", open=False):
                gcode_output = gr.Code(label=None, language=None, lines=12)

    gr.Markdown("---")
    gr.Markdown("machine: 841×1189mm / pen servo 40-90° / [github](https://github.com/Twarner491/dcode) / [model](https://huggingface.co/twarner/dcode-sd-gcode) / mit")

    # The button and Enter in the prompt box both run the same handler.
    generate_inputs = [prompt, temperature, max_tokens, num_steps, guidance]
    generate_outputs = [gcode_output, preview]
    for register in (generate_btn.click, prompt.submit):
        register(generate, generate_inputs, generate_outputs)

if __name__ == "__main__":
    demo.launch()
|
|
|