# app.py
import os
import tempfile
import pathlib
import hashlib
import gradio as gr
import numpy as np
from quread.engine import QuantumStateVector
from quread.exporters import to_openqasm2, to_qiskit, to_cirq, to_csv, csv_to_skill
from quread.llm_explain_openai import explain_with_gpt4o
from quread.circuit_diagram import draw_circuit_svg
from quread.cost_guard import allow_request
from quread.export_pdf import md_to_pdf
from quread.heatmap import make_metric_heatmap, HeatmapConfig
from quread.metrics import (
compute_metrics_from_csv,
to_metrics_csv,
MetricWeights,
MetricThresholds,
)
# --- Qubit cap (configurable) ---
DEFAULT_MAX_QUBITS = 16 # safe default for CPU Spaces; change if you want
MAX_QUBITS = int(os.getenv("QUREAD_MAX_QUBITS", DEFAULT_MAX_QUBITS))
# ---------- Helpers ----------
def _new_sim(n_qubits: int):
    """Build a fresh simulator together with its companion UI state.

    Returns ``(qc, last_counts, selected_gate)``:
    - a new ``QuantumStateVector`` sized to ``n_qubits``;
    - ``None`` for the sampled counts (nothing has been sampled yet);
    - ``"H"`` as the default selected gate.
    """
    return QuantumStateVector(int(n_qubits)), None, "H"
def _probs_top(qc, n_qubits: int, k: int = 6):
probs = (np.abs(qc.state) ** 2)
top = sorted(
[(format(i, f"0{n_qubits}b"), float(p)) for i, p in enumerate(probs)],
key=lambda x: x[1],
reverse=True
)[:k]
return top
def _counts_str(counts):
if not counts:
return ""
return "\n".join([f"{k}: {v}" for k, v in counts.items()])
def _write_tmp(filename: str, content: str) -> str:
base = pathlib.Path(filename).name
stem = pathlib.Path(base).stem or "quread"
suffix = pathlib.Path(base).suffix or ".txt"
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
prefix=f"{stem}_",
suffix=suffix,
dir=tempfile.gettempdir(),
delete=False,
) as f:
f.write(content)
return f.name
def _circuit_hash(history):
return hashlib.sha256(str(history).encode()).hexdigest()
def dl_qasm(qc, n_qubits):
    """Export the circuit history as OpenQASM 2.0 and return the temp-file path."""
    qasm_source = to_openqasm2(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit.qasm", qasm_source)
def dl_qiskit(qc, n_qubits):
    """Export the circuit history as Qiskit code and return the temp-file path."""
    qiskit_source = to_qiskit(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit_qiskit.py", qiskit_source)
def dl_cirq(qc, n_qubits):
    """Export the circuit history as Cirq code and return the temp-file path."""
    cirq_source = to_cirq(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit_cirq.py", cirq_source)
def dl_csv(qc):
    """Export the circuit history as CSV and return the temp-file path."""
    csv_source = to_csv(qc.history)
    return _write_tmp("circuit.csv", csv_source)
def dl_skill(qc, n_qubits):
    """Export the circuit as a Skill script and return the temp-file path.

    The script is derived from the circuit's CSV export via ``csv_to_skill``.
    """
    return _write_tmp(
        "circuit.il",
        csv_to_skill(to_csv(qc.history), n_qubits=int(n_qubits)),
    )
def update_views(qc, last_counts, n_qubits):
    """Recompute every read-only view of the simulator.

    Returns, in order:
    - the diagram produced by ``draw_circuit_svg``;
    - the ket-notation string (up to 16 terms);
    - the formatted measurement counts;
    - the six most probable basis states.
    """
    width = int(n_qubits)
    return (
        draw_circuit_svg(qc.history, n_qubits=width),
        qc.ket_notation(max_terms=16),
        _counts_str(last_counts),
        _probs_top(qc, width, k=6),
    )
def init_or_reset(n_qubits):
    """Return a fresh ``(qc, last_counts, selected_gate)`` plus a status message."""
    fresh_state = _new_sim(n_qubits)
    return (*fresh_state, "✅ Reset done.")
def set_gate(gate):
    """Return ``(gate, status_message)`` so the UI records the chosen gate."""
    confirmation = f"✅ Selected gate: {gate}"
    return gate, confirmation
def apply_selected_gate(qc, last_counts, selected_gate, target):
    """Apply the selected single-qubit gate to ``target``.

    Behaviour:
    - ``qc`` is mutated via ``apply_single`` and returned.
    - Any previously sampled counts are dropped (``None`` is returned in
      their place); ``last_counts`` is accepted only to match the event
      wiring.
    """
    qubit = int(target)
    qc.apply_single(selected_gate, target=qubit)
    return qc, None, f"✅ Applied {selected_gate} on q{target}."
def apply_cnot(qc, last_counts, control, target):
    """Apply a CNOT from ``control`` to ``target`` when they differ.

    Behaviour:
    - On a clash nothing is applied, and ``last_counts`` is passed through
      together with an error message.
    - Otherwise ``qc`` is mutated in place and the stale counts are
      cleared (``None``).
    """
    ctrl, tgt = int(control), int(target)
    if ctrl != tgt:
        qc.apply_cnot(control=ctrl, target=tgt)
        return qc, None, f"✅ Applied CNOT (q{control} -> q{target})."
    return qc, last_counts, "❌ Control and target must be different."
def sample_shots(qc, shots):
    """Return ``(counts, status)`` where ``counts`` comes from ``qc.sample``."""
    counts = qc.sample(shots=int(shots))
    return counts, "✅ Sampled shots."
def measure_collapse(qc, shots):
    """Call ``qc.measure_collapse()`` and then sample ``shots`` outcomes.

    Returns ``(counts, status)``. The status message names the outcome
    reported by ``measure_collapse``.
    """
    outcome = qc.measure_collapse()
    counts = qc.sample(shots=int(shots))
    status_text = f"✅ Collapsed to |{outcome}⟩ and sampled shots."
    return counts, status_text
def explain_llm(qc, n_qubits, shots, last_hash, previous_explanation):
    """Explain the current circuit with GPT-4o, gated on input changes and cost.

    No request is made, and the previous explanation is re-shown if there
    is one, in two cases:
    - nothing the model would receive has changed since the last successful
      explanation;
    - ``allow_request`` refuses the estimated token spend.

    Args:
        qc: simulator whose ``history`` and state are described.
        n_qubits: register width (coerced with ``int``).
        shots: shot count forwarded to the model (coerced with ``int``).
        last_hash: fingerprint stored after the last successful explanation.
        previous_explanation: text of that explanation (``""`` if none).

    Returns:
        ``(shown_markdown, new_last_hash, new_previous_explanation)``. The
        last two only change after a successful request, so skipped, refused
        or failed requests can simply be retried.
    """
    width = int(n_qubits)
    n_shots = int(shots)
    state_ket = qc.ket_notation(max_terms=6)
    probs_top = _probs_top(qc, width, k=6)

    # Fingerprint every input the model receives, not only the gate history.
    # Hashing history alone reused a stale answer in three cases: the qubit
    # count changed with an identical (e.g. empty) history, the shot count
    # changed, or the state changed in a way not recorded in ``qc.history``.
    curr_hash = _circuit_hash((width, n_shots, qc.history, state_ket, probs_top))

    if curr_hash == last_hash:
        if previous_explanation:
            shown = f"ℹ️ Circuit unchanged. Reusing previous explanation.\n\n{previous_explanation}"
            return shown, last_hash, previous_explanation
        return "ℹ️ Circuit unchanged. No previous explanation available.", last_hash, previous_explanation

    # Cost guard: check the token budget before issuing a new request.
    est_tokens = 900  # rough per-request token estimate
    if not allow_request(est_tokens):
        if previous_explanation:
            shown = "🚫 Explanation disabled (daily token limit reached). Showing previous explanation.\n\n"
            shown += previous_explanation
            return shown, last_hash, previous_explanation
        return "🚫 Explanation disabled (daily token limit reached).", last_hash, previous_explanation

    try:
        explanation = explain_with_gpt4o(
            n_qubits=width,
            history=qc.history,
            state_ket=state_ket,
            probs_top=probs_top,
            shots=n_shots,
        )
    except Exception as exc:  # UI boundary: report the failure instead of crashing the event
        safe_msg = f"❌ Explanation request failed: {exc}"
        if previous_explanation:
            shown = f"{safe_msg}\n\nShowing previous explanation:\n\n{previous_explanation}"
            return shown, last_hash, previous_explanation
        return safe_msg, last_hash, previous_explanation

    return explanation, curr_hash, explanation
def _refresh_choices(n):
    """Build replacement dropdowns for an ``n``-qubit register.

    Returns ``(target, control, cnot_target)`` dropdowns that all offer
    qubits ``0..n-1``. The first two default to qubit 0; the CNOT target
    defaults to ``min(1, n - 1)``.
    """
    count = int(n)
    qubits = list(range(count))
    return tuple(
        gr.Dropdown(choices=qubits, value=default)
        for default in (0, 0, min(1, count - 1))
    )
def _on_qubit_count_change(n):
    """Reset the simulator for a new qubit count and rebuild the qubit pickers.

    Returns ``(qc, last_counts, selected_gate, target, control, cnot_target,
    status_message)``.
    """
    sim_parts = _new_sim(n)
    dropdowns = _refresh_choices(n)
    message = f"✅ Reinitialized simulator with {int(n)} qubits."
    return (*sim_parts, *dropdowns, message)
def _metric_controls_to_models(
    activity_w,
    gate_error_w,
    readout_error_w,
    decoherence_w,
    fidelity_w,
    warning_thr,
    critical_thr,
):
    """Convert raw UI slider values into ``MetricWeights`` and ``MetricThresholds``.

    Every value is coerced with ``float``. Returns ``(weights, thresholds)``.
    """
    raw_weights = {
        "activity": activity_w,
        "gate_error": gate_error_w,
        "readout_error": readout_error_w,
        "decoherence": decoherence_w,
        "fidelity": fidelity_w,
    }
    weights = MetricWeights(**{field: float(value) for field, value in raw_weights.items()})
    thresholds = MetricThresholds(warning=float(warning_thr), critical=float(critical_thr))
    return weights, thresholds
def _hotspot_rows(metrics, n_qubits, top_k):
rows = []
n = int(n_qubits)
for q in range(n):
risk = float(metrics["composite_risk"][q])
level = int(metrics["hotspot_level"][q])
status = "critical" if level == 2 else ("warning" if level == 1 else "ok")
rows.append(
[
q,
status,
round(risk, 6),
round(float(metrics["activity_count"][q]), 3),
round(float(metrics["gate_error"][q]), 6),
round(float(metrics["readout_error"][q]), 6),
round(float(metrics["decoherence_risk"][q]), 6),
round(float(metrics["fidelity"][q]), 6),
]
)
rows.sort(key=lambda x: x[2], reverse=True)
k = max(1, min(int(top_k), len(rows)))
return rows[:k]
# ---------- Styling ----------
# Custom stylesheet passed to gr.Blocks(css=...).
# - Id selectors (#title, #subtitle) target components created with a
#   matching elem_id.
# - Class selectors (.sidebar, .card, .section-title, .small-note) target
#   components created with a matching elem_classes entry.
# The .sidebar rule uses position: sticky with a viewport-relative height, so
# the settings column stays pinned while the main column scrolls.
CSS = """
#title h1 { font-size: 42px !important; margin-bottom: 6px; }
#subtitle { color: #6b7280; margin-top: 0px; }
.sidebar {
background: #f6f7fb;
border-right: 1px solid #e5e7eb;
border-radius: 14px;
padding: 18px 14px;
height: calc(100vh - 34px);
position: sticky;
top: 16px;
}
.card {
background: white;
border: 1px solid #e5e7eb;
border-radius: 14px;
padding: 14px;
}
.section-title { font-size: 22px; font-weight: 700; margin: 6px 0 10px; }
.small-note { color: #6b7280; font-size: 12px; }
"""
# Base Gradio theme ("Soft") with large corner radius and medium text size.
theme = gr.themes.Soft(radius_size="lg", text_size="md")
# ---------- UI event handlers ----------
def _init_all(n):
    """Return a fresh ``(qc, last_counts, selected_gate)`` triple for page load."""
    qc, last_counts, selected_gate = _new_sim(n)
    return qc, last_counts, selected_gate


def _gate_setter(gate):
    """Return a zero-argument click handler that selects ``gate``.

    A factory binds ``gate`` per button. A lambda written directly inside a
    loop would hit Python's late-binding closure pitfall.
    """
    return lambda: set_gate(gate)


def _heat_and_hotspots_from_current(
    qc,
    n_qubits,
    rows,
    cols,
    metric,
    calibration_text,
    activity_w,
    gate_error_w,
    readout_error_w,
    decoherence_w,
    fidelity_w,
    warning_thr,
    critical_thr,
    top_k,
):
    """Build the metric heatmap and hotspot table for the current circuit.

    The circuit history is exported with ``to_csv`` and fed to two consumers:
    - ``make_metric_heatmap``, which produces the figure;
    - ``compute_metrics_from_csv``, which produces the per-qubit metrics for
      the hotspot table.
    Both use the same weights, thresholds and optional calibration JSON.

    Returns:
        ``(figure, summary_markdown, hotspot_rows)``.
    """
    csv_text = to_csv(qc.history)
    calibration = str(calibration_text or "")
    cfg = HeatmapConfig(rows=int(rows), cols=int(cols))
    weights, thresholds = _metric_controls_to_models(
        activity_w,
        gate_error_w,
        readout_error_w,
        decoherence_w,
        fidelity_w,
        warning_thr,
        critical_thr,
    )
    fig = make_metric_heatmap(
        csv_text=csv_text,
        n_qubits=int(n_qubits),
        metric=str(metric),
        cfg=cfg,
        calibration_json=calibration,
        weights=weights,
        thresholds=thresholds,
    )
    metrics, meta = compute_metrics_from_csv(
        csv_text,
        int(n_qubits),
        calibration_json=calibration,
        weights=weights,
        thresholds=thresholds,
    )
    hotspot_rows = _hotspot_rows(metrics, int(n_qubits), int(top_k))

    notes = []
    skipped = int(meta.get("skipped_rows", 0))
    if skipped:
        notes.append(f"Skipped malformed CSV rows: {skipped}")
    calibration_note = str(meta.get("calibration_note", "") or "").strip()
    if calibration_note:
        notes.append(calibration_note)
    summary = " | ".join(notes) if notes else "Hotspot ranking updated."
    return fig, summary, hotspot_rows


def _dl_metrics_csv(
    qc,
    n_qubits,
    calibration_text,
    activity_w,
    gate_error_w,
    readout_error_w,
    decoherence_w,
    fidelity_w,
    warning_thr,
    critical_thr,
):
    """Compute per-qubit metrics for the current circuit; return a temp CSV path."""
    csv_text = to_csv(qc.history)
    weights, thresholds = _metric_controls_to_models(
        activity_w,
        gate_error_w,
        readout_error_w,
        decoherence_w,
        fidelity_w,
        warning_thr,
        critical_thr,
    )
    metrics, _meta = compute_metrics_from_csv(
        csv_text,
        int(n_qubits),
        calibration_json=str(calibration_text or ""),
        weights=weights,
        thresholds=thresholds,
    )
    return _write_tmp("qubit_metrics.csv", to_metrics_csv(metrics))


def dl_explain_md(md_text):
    """Write the last explanation to a temp Markdown file and return its path."""
    return _write_tmp("explanation.md", md_text)


def dl_explain_pdf(md_text):
    """Render the last explanation to a temp PDF via ``md_to_pdf``; return its path.

    If rendering fails, the empty or partial temp file is removed before the
    error propagates, so failed downloads do not accumulate files in the
    temp directory.
    """
    fd, path = tempfile.mkstemp(
        prefix="explanation_",
        suffix=".pdf",
        dir=tempfile.gettempdir(),
    )
    os.close(fd)
    try:
        md_to_pdf(md_text, path)
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    return path


# Initial register size, clamped so the slider value never exceeds MAX_QUBITS.
_INITIAL_QUBITS = min(2, MAX_QUBITS)

with gr.Blocks(theme=theme, css=CSS, title="Quread.ai — State Vector Studio") as demo:
    qc_state = gr.State()
    last_counts_state = gr.State()
    selected_gate_state = gr.State()
    explanation_md = gr.State("")
    last_explained_hash = gr.State("")

    with gr.Row():
        # Sidebar
        with gr.Column(scale=3, elem_classes=["sidebar"]):
            gr.Markdown("### Simulator Settings")
            n_qubits = gr.Slider(1, MAX_QUBITS, value=_INITIAL_QUBITS, step=1, label="Number of qubits")
            gr.Markdown(
                f"Max qubits: {MAX_QUBITS} (set env var QUREAD_MAX_QUBITS to change)",
                elem_classes=["small-note"],
            )
            shots = gr.Slider(128, 8192, value=1024, step=128, label="Shots")
            gr.Markdown("---")
            gr.Markdown("### Explanation (GPT-4o)")
            gr.Markdown("Uses GPT-4o with cost guards.", elem_classes=["small-note"])
            reset_btn = gr.Button("Reset Simulator", variant="secondary")

        # Main
        with gr.Column(scale=9):
            gr.Markdown("# Quread.ai — State Vector Studio", elem_id="title")
            gr.Markdown(
                "Build circuits, visualize, export, and get teaching-ready explanations.",
                elem_id="subtitle",
            )
            status = gr.Markdown("")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("Gate Palette", elem_classes=["section-title"])
                with gr.Row():
                    gate_H = gr.Button("H")
                    gate_T = gr.Button("T")
                    gate_Tdg = gr.Button("T†")
                    gate_X = gr.Button("X")
                with gr.Row():
                    gate_Y = gr.Button("Y")
                    gate_Z = gr.Button("Z")
                    gate_sx = gr.Button("√X")
                    gate_sz = gr.Button("√Z")
                with gr.Row():
                    gate_rx_pi = gr.Button("Rx(π)")
                    gate_rx_pi2 = gr.Button("Rx(π/2)")
                    gate_ry_pi = gr.Button("Ry(π)")
                    gate_ry_pi2 = gr.Button("Ry(π/2)")
                with gr.Row():
                    gate_rz_pi = gr.Button("Rz(π)")
                    gate_rz_pi2 = gr.Button("Rz(π/2)")
                    gate_I = gr.Button("I")
                    gate_Idg = gr.Button("I†")  # will map to I
                with gr.Row():
                    gate_S = gr.Button("S")
                    gate_Sdg = gr.Button("S†")
                    target = gr.Dropdown(
                        choices=list(range(_INITIAL_QUBITS)), value=0, label="Target qubit"
                    )
                with gr.Row():
                    apply_gate_btn = gr.Button("Apply Selected Gate", variant="primary")
                    sample_btn = gr.Button("Sample shots")
                    measure_btn = gr.Button("Measure + Collapse")
                gr.Markdown("**CNOT**")
                with gr.Row():
                    control = gr.Dropdown(
                        choices=list(range(_INITIAL_QUBITS)), value=0, label="Control"
                    )
                    cnot_target = gr.Dropdown(
                        choices=list(range(_INITIAL_QUBITS)),
                        value=min(1, _INITIAL_QUBITS - 1),
                        label="Target",
                    )
                    apply_cnot_btn = gr.Button("Apply CNOT")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("Circuit Diagram", elem_classes=["section-title"])
                circuit_html = gr.HTML()

            with gr.Row():
                with gr.Column(scale=7):
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("Statevector", elem_classes=["section-title"])
                        ket_out = gr.Code(label="", language="python")
                        gr.Markdown("**Top probabilities**")
                        probs_out = gr.Dataframe(headers=["bitstring", "prob"], interactive=False)
                with gr.Column(scale=5):
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("Measurement distribution", elem_classes=["section-title"])
                        counts_out = gr.Textbox(lines=10)
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("Export", elem_classes=["section-title"])
                        qasm_dl = gr.DownloadButton("Download OpenQASM 2.0")
                        qiskit_dl = gr.DownloadButton("Download Qiskit code")
                        cirq_dl = gr.DownloadButton("Download Cirq code")
                        csv_dl = gr.DownloadButton("Download CSV")
                        skill_dl = gr.DownloadButton("Download Skill script")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("Heatmap", elem_classes=["section-title"])
                chip_rows = gr.Slider(2, 64, value=8, step=1, label="Chip rows")
                chip_cols = gr.Slider(2, 64, value=8, step=1, label="Chip cols")
                heat_metric = gr.Dropdown(
                    choices=[
                        "activity_count",
                        "activity_norm",
                        "gate_error",
                        "readout_error",
                        "decoherence_risk",
                        "fidelity",
                        "composite_risk",
                    ],
                    value="activity_count",
                    label="Heatmap metric",
                )
                calibration_json = gr.Textbox(
                    lines=6,
                    label="Calibration JSON (optional)",
                    placeholder='{"qubits":{"0":{"gate_error":0.012,"readout_error":0.02,"t1_us":82,"t2_us":61,"fidelity":0.991}}}',
                )
                gr.Markdown(
                    "Composite risk weights are normalized automatically.",
                    elem_classes=["small-note"],
                )
                with gr.Row():
                    w_activity = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Weight: activity")
                    w_gate = gr.Slider(0.0, 1.0, value=0.20, step=0.01, label="Weight: gate error")
                    w_readout = gr.Slider(0.0, 1.0, value=0.15, step=0.01, label="Weight: readout error")
                with gr.Row():
                    w_decoherence = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Weight: decoherence")
                    w_fidelity = gr.Slider(0.0, 1.0, value=0.15, step=0.01, label="Weight: fidelity risk")
                with gr.Row():
                    thr_warning = gr.Slider(0.0, 1.0, value=0.45, step=0.01, label="Threshold: warning")
                    thr_critical = gr.Slider(0.0, 1.0, value=0.70, step=0.01, label="Threshold: critical")
                    hotspot_top_k = gr.Slider(1, 64, value=16, step=1, label="Hotspot rows")
                metrics_csv_dl = gr.DownloadButton("Download metrics CSV")
                heat_btn = gr.Button("Generate heatmap from CSV", variant="secondary")
                heat_plot = gr.Plot()
                hotspot_status = gr.Markdown()
                hotspot_table = gr.Dataframe(
                    headers=[
                        "qubit",
                        "status",
                        "composite_risk",
                        "activity_count",
                        "gate_error",
                        "readout_error",
                        "decoherence_risk",
                        "fidelity",
                    ],
                    interactive=False,
                    label="Hotspot ranking (highest composite risk first)",
                )

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("Explain (GPT-4o)", elem_classes=["section-title"])
                explain_btn = gr.Button("Explain", variant="primary")
                llm_out = gr.Markdown()
                with gr.Row():
                    dl_md = gr.DownloadButton("Download Explanation (Markdown)")
                    dl_pdf = gr.DownloadButton("Download Explanation (PDF)")
                gr.Markdown("Tip: sample shots first.", elem_classes=["small-note"])

    def _then_refresh(event):
        """Chain a full view refresh (diagram, ket, counts, probabilities) after ``event``."""
        return event.then(
            fn=update_views,
            inputs=[qc_state, last_counts_state, n_qubits],
            outputs=[circuit_html, ket_out, counts_out, probs_out],
        )

    # Event registration order matches the original wiring.
    _then_refresh(
        demo.load(
            fn=_init_all,
            inputs=[n_qubits],
            outputs=[qc_state, last_counts_state, selected_gate_state],
        )
    )
    _then_refresh(
        n_qubits.change(
            fn=_on_qubit_count_change,
            inputs=[n_qubits],
            outputs=[qc_state, last_counts_state, selected_gate_state, target, control, cnot_target, status],
        )
    )
    _then_refresh(
        reset_btn.click(
            fn=init_or_reset,
            inputs=[n_qubits],
            outputs=[qc_state, last_counts_state, selected_gate_state, status],
        )
    )

    # Button -> gate name stored in selected_gate_state. Rotation gates use
    # upper-case "RX"/"RY"/"RZ" names even though the labels read "Rx"/"Ry"/"Rz".
    for _button, _gate_name in (
        (gate_H, "H"),
        (gate_T, "T"),
        (gate_Tdg, "T†"),
        (gate_X, "X"),
        (gate_Y, "Y"),
        (gate_Z, "Z"),
        (gate_sx, "√X"),
        (gate_sz, "√Z"),
        (gate_rx_pi, "RX(π)"),
        (gate_rx_pi2, "RX(π/2)"),
        (gate_ry_pi, "RY(π)"),
        (gate_ry_pi2, "RY(π/2)"),
        (gate_rz_pi, "RZ(π)"),
        (gate_rz_pi2, "RZ(π/2)"),
        (gate_I, "I"),
        (gate_Idg, "I"),  # treat I† as I
        (gate_S, "S"),
        (gate_Sdg, "S†"),
    ):
        _button.click(fn=_gate_setter(_gate_name), outputs=[selected_gate_state, status])

    _then_refresh(
        apply_gate_btn.click(
            fn=apply_selected_gate,
            inputs=[qc_state, last_counts_state, selected_gate_state, target],
            outputs=[qc_state, last_counts_state, status],
        )
    )
    _then_refresh(
        apply_cnot_btn.click(
            fn=apply_cnot,
            inputs=[qc_state, last_counts_state, control, cnot_target],
            outputs=[qc_state, last_counts_state, status],
        )
    )
    _then_refresh(
        sample_btn.click(
            fn=sample_shots,
            inputs=[qc_state, shots],
            outputs=[last_counts_state, status],
        )
    )
    _then_refresh(
        measure_btn.click(
            fn=measure_collapse,
            inputs=[qc_state, shots],
            outputs=[last_counts_state, status],
        )
    )

    qasm_dl.click(fn=dl_qasm, inputs=[qc_state, n_qubits], outputs=[qasm_dl])
    qiskit_dl.click(fn=dl_qiskit, inputs=[qc_state, n_qubits], outputs=[qiskit_dl])
    cirq_dl.click(fn=dl_cirq, inputs=[qc_state, n_qubits], outputs=[cirq_dl])
    csv_dl.click(fn=dl_csv, inputs=[qc_state], outputs=[csv_dl])
    skill_dl.click(fn=dl_skill, inputs=[qc_state, n_qubits], outputs=[skill_dl])

    explain_btn.click(
        fn=explain_llm,
        inputs=[qc_state, n_qubits, shots, last_explained_hash, explanation_md],
        outputs=[llm_out, last_explained_hash, explanation_md],
    )

    heat_btn.click(
        fn=_heat_and_hotspots_from_current,
        inputs=[
            qc_state,
            n_qubits,
            chip_rows,
            chip_cols,
            heat_metric,
            calibration_json,
            w_activity,
            w_gate,
            w_readout,
            w_decoherence,
            w_fidelity,
            thr_warning,
            thr_critical,
            hotspot_top_k,
        ],
        outputs=[heat_plot, hotspot_status, hotspot_table],
    )
    metrics_csv_dl.click(
        fn=_dl_metrics_csv,
        inputs=[
            qc_state,
            n_qubits,
            calibration_json,
            w_activity,
            w_gate,
            w_readout,
            w_decoherence,
            w_fidelity,
            thr_warning,
            thr_critical,
        ],
        outputs=[metrics_csv_dl],
    )

    dl_md.click(fn=dl_explain_md, inputs=[explanation_md], outputs=[dl_md])
    dl_pdf.click(fn=dl_explain_pdf, inputs=[explanation_md], outputs=[dl_pdf])
def main():
    """Serve the Gradio app on all network interfaces at port 7860."""
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()