"""
TorchForge Demo — Hugging Face Space
Enterprise-grade PyTorch framework with governance, monitoring, and deployment.
"""
import html
import json
import sys

import gradio as gr
# ---------------------------------------------------------------------------
# Safe imports — some heavy deps (torch) may not be present in the Space env
# ---------------------------------------------------------------------------
# TORCHFORGE_AVAILABLE gates every handler below. IMPORT_ERROR holds the text of
# the import failure, or None when TorchForge imported successfully.
IMPORT_ERROR = None
try:
    from torchforge.core.config import (
        ForgeConfig, MonitoringConfig, GovernanceConfig,
        OptimizationConfig, DeploymentConfig,
    )
    from torchforge.governance.compliance import ComplianceChecker
    from torchforge.monitoring.metrics import MetricsCollector
    from torchforge.monitoring.monitor import ModelMonitor
    from torchforge.deployment.manager import DeploymentManager, DeploymentMetrics
    TORCHFORGE_AVAILABLE = True
except Exception as exc:  # broad on purpose: any import-time failure should degrade the demo, not crash it
    TORCHFORGE_AVAILABLE = False
    IMPORT_ERROR = str(exc)
    # The UI tells users to "Check Space logs", so make sure the reason is actually logged.
    print(
        f"TorchForge import failed ({type(exc).__name__}): {IMPORT_ERROR}",
        file=sys.stderr,
    )
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _unavailable():
return "⚠️ TorchForge could not be imported in this environment. Check Space logs."
# ---------------------------------------------------------------------------
# Tab 1 — Compliance Checker
# ---------------------------------------------------------------------------
def run_compliance(
    model_name, model_version,
    has_governance, has_risk_map, has_impact_assessment,
    has_risk_mgmt, has_transparency, has_fairness, has_security,
):
    """Run a ComplianceChecker assessment for the Compliance Checker tab.

    Args:
        model_name: Model name from the textbox; "my-model" when empty.
        model_version: Version string; "1.0.0" when empty.
        has_governance, has_risk_map, has_impact_assessment, has_risk_mgmt,
        has_transparency, has_fairness, has_security: Checkbox answers that are
            forwarded to the checker as the ``has_*`` metadata flags.

    Returns:
        A ``(summary_markdown, results_html)`` tuple. When TorchForge is
        unavailable, the summary is the unavailability notice and the HTML is "".
    """
    if not TORCHFORGE_AVAILABLE:
        return _unavailable(), ""
    metadata = {
        "model_name": model_name or "my-model",
        "version": model_version or "1.0.0",
        "has_governance_policy": has_governance,
        "has_risk_mapping": has_risk_map,
        "has_impact_assessment": has_impact_assessment,
        "has_risk_management": has_risk_mgmt,
        "has_transparency_docs": has_transparency,
        "has_fairness_evaluation": has_fairness,
        "has_security_assessment": has_security,
    }
    report = ComplianceChecker().assess_compliance(metadata)

    # Show at most 8 recommendations to keep the summary readable.
    recommendations = "\n".join(f"- {rec}" for rec in report.recommendations[:8])
    summary = (
        f"**Model:** {report.model_name} v{report.model_version}\n\n"
        f"**Overall Score:** {report.overall_score:.1f} / 100\n\n"
        f"**Risk Level:** {report.risk_level.upper()}\n\n"
        f"**Checks Passed:** {report.passed_checks} / {report.total_checks}\n\n"
        "---\n\n**Recommendations:**\n\n"
        + (recommendations or "_No recommendations._")
    )

    # gr.HTML renders raw markup, and check names/details may echo user input
    # (e.g. the model name), so escape them before embedding.
    rows = "".join(
        "<tr>"
        f"<td>{html.escape(str(r.check_name))}</td>"
        f"<td style='color:{'green' if r.passed else 'red'}'>"
        f"{'✔ Pass' if r.passed else '✘ Fail'}</td>"
        f"<td>{r.score:.0f}</td>"
        f"<td>{html.escape(str(r.details))}</td>"
        "</tr>"
        for r in report.results
    )
    table_html = (
        "<table style='width:100%; border-collapse:collapse'>"
        "<thead><tr>"
        "<th>Check</th><th>Status</th><th>Score</th><th>Details</th>"
        "</tr></thead>"
        f"<tbody>{rows}</tbody>"
        "</table>"
    )
    return summary, table_html
# ---------------------------------------------------------------------------
# Tab 2 — Configuration Builder
# ---------------------------------------------------------------------------
def build_config(
    model_name, version, env,
    enable_monitoring, enable_drift, enable_fairness,
    enable_governance, enable_audit, enable_bias,
    enable_profiling, enable_quantization,
    deploy_target,
):
    """Assemble a ForgeConfig from the Config Builder form and render it as JSON.

    Empty model name / version fall back to "my-model" / "1.0.0".

    Returns:
        A Markdown fenced ``json`` block containing ``to_dict()`` of the config,
        pretty-printed with 2-space indentation, or the unavailability notice
        when TorchForge could not be imported.
    """
    if not TORCHFORGE_AVAILABLE:
        return _unavailable()

    monitoring_cfg = MonitoringConfig(
        enabled=enable_monitoring,
        drift_detection=enable_drift,
        fairness_tracking=enable_fairness,
    )
    # NOTE(review): enable_governance is accepted but never forwarded to
    # GovernanceConfig; confirm whether GovernanceConfig has an on/off field.
    governance_cfg = GovernanceConfig(
        audit_logging=enable_audit,
        bias_detection=enable_bias,
    )
    optimization_cfg = OptimizationConfig(
        profiling_enabled=enable_profiling,
        quantization_enabled=enable_quantization,
    )
    deployment_cfg = DeploymentConfig(target=deploy_target)

    forge_cfg = ForgeConfig(
        model_name=model_name or "my-model",
        version=version or "1.0.0",
        environment=env,
        monitoring=monitoring_cfg,
        governance=governance_cfg,
        optimization=optimization_cfg,
        deployment=deployment_cfg,
    )
    pretty_json = json.dumps(forge_cfg.to_dict(), indent=2)
    return "```json\n" + pretty_json + "\n```"
# ---------------------------------------------------------------------------
# Tab 3 — Metrics Simulator
# ---------------------------------------------------------------------------
import random, time
def simulate_metrics(n_inferences, error_rate_pct):
    """Drive a MetricsCollector with synthetic traffic and summarize its stats.

    Args:
        n_inferences: Number of inferences to simulate. Slider values are not
            guaranteed to be ints, so this is rounded to an int (minimum 0).
        error_rate_pct: Per-inference chance, in percent, of also recording an error.

    Returns:
        Markdown listing counts, error rate, latency percentiles, health status
        and uptime, or the unavailability notice when TorchForge could not be imported.
    """
    if not TORCHFORGE_AVAILABLE:
        return _unavailable()
    # range() rejects floats, so coerce the slider value first.
    n = max(0, int(round(n_inferences)))
    error_rate = float(error_rate_pct) / 100.0

    collector = MetricsCollector(window_size=max(n, 10))
    for _ in range(n):
        # Gaussian latency around 50 ms (sd 10 ms), floored at 1 ms to stay positive.
        collector.record_inference(max(0.001, random.gauss(0.05, 0.01)))
        if random.random() < error_rate:
            collector.record_error()
    stats = collector.get_stats()

    # NOTE(review): this monitor is created fresh and is never given `collector`,
    # so its health status does not reflect the simulated traffic above; confirm
    # whether ModelMonitor can be attached to an existing collector.
    health = ModelMonitor("demo-model", "1.0.0").get_health_status()

    return (
        f"**Inferences recorded:** {stats.get('inference_count', 0)}\n\n"
        f"**Error count:** {stats.get('error_count', 0)}\n\n"
        f"**Error rate:** {stats.get('error_rate', 0):.1%}\n\n"
        f"**Mean latency:** {stats.get('mean_latency', 0)*1000:.1f} ms\n\n"
        f"**p95 latency:** {stats.get('p95_latency', 0)*1000:.1f} ms\n\n"
        f"**p99 latency:** {stats.get('p99_latency', 0)*1000:.1f} ms\n\n"
        f"**Health status:** {health.get('status', 'unknown').upper()}\n\n"
        f"**Uptime:** {stats.get('uptime_seconds', 0):.1f}s"
    )
# ---------------------------------------------------------------------------
# Tab 4 — Deployment Simulator
# ---------------------------------------------------------------------------
def simulate_deployment(model_name, version, target, min_inst, max_inst):
    """Run DeploymentManager.deploy() for the Deployment Simulator tab.

    Args:
        model_name: Model name; "my-model" when empty.
        version: Version string; "1.0.0" when empty.
        target: Deployment target from the dropdown (e.g. "aws", "docker").
        min_inst: Minimum instance count (slider value; rounded to int).
        max_inst: Maximum instance count (slider value; rounded to int).

    Returns:
        Markdown with the deploy result and the manager's metrics. Returns the
        unavailability notice when TorchForge could not be imported, or a warning
        when min_inst exceeds max_inst.
    """
    if not TORCHFORGE_AVAILABLE:
        return _unavailable()
    # Slider values may arrive as floats; instance counts must be whole numbers.
    min_instances = int(round(min_inst))
    max_instances = int(round(max_inst))
    # The two sliders are independent, so the UI allows min > max; reject that
    # instead of deploying with an impossible scaling range.
    if min_instances > max_instances:
        return (
            f"⚠️ Min instances ({min_instances}) cannot exceed "
            f"max instances ({max_instances})."
        )

    cfg = ForgeConfig(
        model_name=model_name or "my-model",
        version=version or "1.0.0",
        deployment=DeploymentConfig(
            target=target,
            min_instances=min_instances,
            max_instances=max_instances,
        ),
    )
    manager = DeploymentManager(cfg)
    result = manager.deploy()
    metrics = manager.get_metrics()
    return (
        f"**Status:** {result.get('status', 'unknown').upper()}\n\n"
        f"**Endpoint:** `{result.get('endpoint', 'N/A')}`\n\n"
        f"**Target:** {target}\n\n"
        f"**Instances:** {min_instances} – {max_instances}\n\n"
        "---\n\n**Simulated Metrics:**\n\n"
        f"- p95 latency: {metrics.p95_latency_ms:.1f} ms\n"
        f"- p99 latency: {metrics.p99_latency_ms:.1f} ms\n"
        f"- Requests/sec: {metrics.requests_per_second:.1f}\n"
        f"- Error rate: {metrics.error_rate:.2%}\n"
        f"- Active instances: {metrics.active_instances}"
    )
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
ABOUT_MD = """
# TorchForge — Enterprise PyTorch Framework
**TorchForge** wraps your PyTorch models with production-grade capabilities:
| Feature | Description |
|---|---|
| 🏛️ Governance | NIST AI RMF compliance, audit logging, bias detection |
| 📊 Monitoring | Drift detection, fairness tracking, latency metrics |
| 🚀 Deployment | Multi-cloud (AWS / Azure / GCP / K8s / Docker) |
| ⚡ Optimization | Profiling, quantization, ONNX export |
| 🔐 Security | Provenance tracking, model lineage |
### Install
```bash
pip install pytorchforge
```
### Quick start
```python
import torch.nn as nn
from torchforge import ForgeModel, ForgeConfig
config = ForgeConfig(model_name="my-classifier", version="1.0.0")
model = ForgeModel(nn.Linear(128, 10), config)
output = model(input_tensor) # forward pass + auto-metrics
model.save_checkpoint("model.pt")
```
**Links:** [PyPI](https://pypi.org/project/pytorchforge/) · [GitHub](https://github.com/anilatambharii/torchforge)
"""
# Gradio UI: one tab per demo feature. Each button's `inputs` list is passed to
# its handler positionally, so list order must match the handler's signature.
with gr.Blocks(title="TorchForge Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔥 TorchForge — Enterprise PyTorch Framework Demo")
    with gr.Tabs():
        # --- About ---
        with gr.Tab("About"):
            gr.Markdown(ABOUT_MD)

        # --- Compliance ---
        with gr.Tab("Compliance Checker"):
            gr.Markdown("### NIST AI Risk Management Framework Assessment")
            with gr.Row():
                with gr.Column():
                    c_name = gr.Textbox(label="Model Name", value="my-classifier")
                    c_version = gr.Textbox(label="Version", value="1.0.0")
                    gr.Markdown("**Which controls does your model have?**")
                    c_gov = gr.Checkbox(label="Governance policy", value=True)
                    c_risk = gr.Checkbox(label="Risk mapping")
                    c_imp = gr.Checkbox(label="Impact assessment")
                    c_mgmt = gr.Checkbox(label="Risk management plan")
                    c_tran = gr.Checkbox(label="Transparency documentation")
                    c_fair = gr.Checkbox(label="Fairness evaluation")
                    c_sec = gr.Checkbox(label="Security assessment")
                    c_btn = gr.Button("Run Assessment", variant="primary")
                with gr.Column():
                    c_summary = gr.Markdown(label="Summary")
                    c_table = gr.HTML(label="Detailed Results")
            c_btn.click(
                run_compliance,
                inputs=[c_name, c_version, c_gov, c_risk, c_imp, c_mgmt, c_tran, c_fair, c_sec],
                outputs=[c_summary, c_table],
            )

        # --- Config Builder ---
        with gr.Tab("Config Builder"):
            gr.Markdown("### Generate a ForgeConfig for your model")
            with gr.Row():
                with gr.Column():
                    cfg_name = gr.Textbox(label="Model Name", value="my-model")
                    cfg_version = gr.Textbox(label="Version", value="1.0.0")
                    cfg_env = gr.Dropdown(["development", "staging", "production"], label="Environment", value="production")
                    cfg_target = gr.Dropdown(["local", "docker", "kubernetes", "aws", "azure", "gcp"], label="Deploy Target", value="aws")
                    gr.Markdown("**Monitoring**")
                    cfg_mon = gr.Checkbox(label="Enable monitoring", value=True)
                    cfg_drift = gr.Checkbox(label="Drift detection", value=True)
                    cfg_fair = gr.Checkbox(label="Fairness tracking")
                    gr.Markdown("**Governance**")
                    cfg_gov = gr.Checkbox(label="Enable governance", value=True)
                    cfg_audit = gr.Checkbox(label="Audit logging", value=True)
                    cfg_bias = gr.Checkbox(label="Bias detection")
                    gr.Markdown("**Optimization**")
                    cfg_prof = gr.Checkbox(label="Profiling")
                    cfg_quant = gr.Checkbox(label="Quantization")
                    cfg_btn = gr.Button("Generate Config", variant="primary")
                with gr.Column():
                    cfg_out = gr.Markdown(label="Generated Config (JSON)")
            # build_config takes 12 positional parameters, with enable_governance
            # 7th. cfg_gov must be supplied there, or every later value shifts
            # by one and deploy_target is missing (TypeError on click).
            cfg_btn.click(
                build_config,
                inputs=[
                    cfg_name, cfg_version, cfg_env,
                    cfg_mon, cfg_drift, cfg_fair,
                    cfg_gov, cfg_audit, cfg_bias,
                    cfg_prof, cfg_quant,
                    cfg_target,
                ],
                outputs=cfg_out,
            )

        # --- Metrics ---
        with gr.Tab("Metrics Simulator"):
            gr.Markdown("### Simulate model inference metrics")
            with gr.Row():
                with gr.Column():
                    m_n = gr.Slider(10, 500, value=100, step=10, label="Number of inferences")
                    m_err = gr.Slider(0, 20, value=2, step=1, label="Error rate (%)")
                    m_btn = gr.Button("Simulate", variant="primary")
                with gr.Column():
                    m_out = gr.Markdown()
            m_btn.click(simulate_metrics, inputs=[m_n, m_err], outputs=m_out)

        # --- Deployment ---
        with gr.Tab("Deployment Simulator"):
            gr.Markdown("### Simulate a cloud deployment")
            with gr.Row():
                with gr.Column():
                    d_name = gr.Textbox(label="Model Name", value="my-model")
                    d_ver = gr.Textbox(label="Version", value="1.0.0")
                    d_target = gr.Dropdown(["local", "docker", "kubernetes", "aws", "azure", "gcp"], label="Target", value="aws")
                    d_min = gr.Slider(1, 5, value=2, step=1, label="Min instances")
                    d_max = gr.Slider(2, 20, value=5, step=1, label="Max instances")
                    d_btn = gr.Button("Deploy", variant="primary")
                with gr.Column():
                    d_out = gr.Markdown()
            d_btn.click(simulate_deployment, inputs=[d_name, d_ver, d_target, d_min, d_max], outputs=d_out)

if __name__ == "__main__":
    demo.launch()