g-ssm-xor / app.py
joaquinsturtz's picture
Docs: Correct parameter count to verified 3,164
bd7efac verified
import os
# Make Gradio believe it is running inside the "DepthMuun/g-ssm-xor" HF Space
# and pin the server address/port. This must happen before ``gradio`` is
# imported below.
os.environ.update(
    {
        "SPACE_ID": "DepthMuun/g-ssm-xor",
        "GRADIO_SERVER_NAME": "0.0.0.0",
        "GRADIO_SERVER_PORT": "7860",
    }
)
import gradio as gr
import torch
import math
import sys
import json
from pathlib import Path
# Directory containing this script; a ``gfn`` folder shipped next to it
# (the HF Spaces layout) takes priority over any other copy on sys.path.
script_dir = os.path.dirname(os.path.abspath(__file__))
_local_gfn = os.path.join(script_dir, "gfn")
if os.path.exists(_local_gfn):
    sys.path.insert(0, script_dir)

import gfn  # noqa: E402 -- must run after the sys.path adjustment above
def load_model():
    """Build the G-SSM model from ``config.json`` and load its trained weights.

    ``config.json`` and ``xor_best_model.bin`` are both resolved relative to
    the directory containing this script, so the current working directory
    does not matter.

    Returns:
        tuple: ``(model, device)``. The model is in eval mode on ``device``,
        which is CUDA when available and CPU otherwise.

    Raises:
        FileNotFoundError: If ``config.json`` is missing.
        KeyError: If a required ``architecture`` / ``physics`` entry is
            missing from the config.
    """
    import warnings  # only needed for the missing-checkpoint warning below

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_dir = os.path.dirname(os.path.abspath(__file__))

    with open(os.path.join(base_dir, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)
    arch = config["architecture"]

    model = gfn.gssm.create(
        vocab_size=arch["vocab_size"],
        dim=arch["dim"],
        depth=arch["depth"],
        heads=arch["heads"],
        physics=config["physics"],
        trajectory_mode=arch["trajectory_mode"],
        coupler_mode=arch["coupler_mode"],
        initial_spread=arch["initial_spread"],
        integrator=arch["integrator"],
        holographic=arch.get("holographic", True),
    ).to(device)

    checkpoint_path = os.path.join(base_dir, "xor_best_model.bin")
    if os.path.exists(checkpoint_path):
        # weights_only=True restricts unpickling to tensors/plain containers.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        # Keep the UI usable, but make it visible that predictions will come
        # from freshly created (untrained) weights.
        warnings.warn(
            f"Checkpoint not found at {checkpoint_path!r}; "
            "running the model without trained weights.",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model, device


# Build the model once at import time; every request reuses it.
model, device = load_model()
import json
import tempfile
def _circular_distance(state, target):
    """Return the mean shortest angular distance from ``state`` to ``target``.

    Each element of ``state`` is treated as an angle. Its distance to
    ``target`` is measured around a circle of circumference 2*pi, so each
    per-element distance lies in [0, pi]. The distances are averaged and
    returned as a Python float.
    """
    two_pi = 2.0 * math.pi
    wrapped = torch.abs(state - target) % two_pi
    return torch.min(wrapped, two_pi - wrapped).mean().item()


def predict_parity(bitstream):
    """Predict the XOR parity of ``bitstream`` with the G-SSM and grade it.

    Whitespace anywhere in the input (spaces, the newline a 2-line textbox
    allows) is ignored.

    Args:
        bitstream: String of ``'0'``/``'1'`` characters from the UI textbox.
            ``None`` is treated as empty.

    Returns:
        tuple[str, str]: ``(status, accuracy_label)``, one value per output
        textbox wired in the UI. On invalid or empty input the status is an
        error message and the label is ``"N/A"``.
    """
    # Every path must return exactly two values: the UI binds two outputs.
    bits_text = "".join((bitstream or "").split())
    if not bits_text:
        return "Empty input", "N/A"
    if any(c not in "01" for c in bits_text):
        return "Error: Input must be a binary string.", "N/A"

    bits = [int(c) for c in bits_text]
    x_in = torch.tensor([bits], device=device)  # shape (1, T)

    with torch.no_grad():
        output = model(x_in)
    # NOTE(review): assumes output[0] is the state trajectory shaped
    # (batch, seq, dim), per the original "[B, T, D]" note -- confirm.
    final_state = output[0][0, -1, :]

    # Decode: parity 1 if the final state is (on average) closer to +pi/2 on
    # the circle, parity 0 if closer to -pi/2.
    half_pi = 0.5 * math.pi
    dist_pos = _circular_distance(final_state, half_pi)
    dist_neg = _circular_distance(final_state, -half_pi)
    prediction = 1 if dist_pos < dist_neg else 0

    target_parity = sum(bits) % 2  # XOR of all bits == parity of the 1-count
    is_correct = prediction == target_parity
    accuracy = 100.0 if is_correct else 0.0
    status = "✅ SUCCESS" if is_correct else "❌ FAILURE"
    return status, f"{accuracy}% Accuracy"
# Gradio UI: one input textbox and two output textboxes (status, accuracy).
with gr.Blocks(title="G-SSM XOR Parity Solver", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🌀 G-SSM XOR Parity Solver")
    gr.Markdown(
        "**Params: 3,164** | **Memory: O(1) Constant** | "
        "**Architecture: G-SSM (Geodesic State Space Model)**\n\n"
        "This model demonstrates **topological logic generalization**. "
        "It solves the binary parity (XOR) problem by integrating bit-impulses "
        "as physical forces on a 1D manifold. The state evolves via geodesic "
        "flow, accumulating phase shifts that represent the cumulative parity "
        "without requiring self-attention or reprocessing past context."
    )
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Binary Stream",
                placeholder="Enter 0s and 1s...",
                value="10110",
                lines=2,
            )
            submit_btn = gr.Button("🔥 Run Geometric Inference", variant="primary")
    with gr.Row():
        status_output = gr.Textbox(label="Status")
        acc_label = gr.Textbox(label="Geometric Accuracy Metric")

    # Clicking the button and pressing Enter in the textbox run the same
    # handler; its two return values fill the two output boxes in order.
    for trigger in (submit_btn.click, input_text.submit):
        trigger(
            fn=predict_parity,
            inputs=input_text,
            outputs=[status_output, acc_label],
        )
if __name__ == "__main__":
    # Enable request queuing, then start the server with the API page hidden.
    # Host/port are presumably taken from the GRADIO_SERVER_* variables set
    # at the top of this file.
    queued_app = demo.queue()
    queued_app.launch(show_api=False)