QureadAI / app.py
hchevva's picture
Update app.py
d8031b4 verified
# app.py
import os
import tempfile
import pathlib
import hashlib
import gradio as gr
import numpy as np
from quread.engine import QuantumStateVector
from quread.exporters import to_openqasm2, to_qiskit, to_cirq, to_csv, csv_to_skill
from quread.llm_explain_openai import explain_with_gpt4o
from quread.circuit_diagram import draw_circuit_svg
from quread.cost_guard import allow_request
from quread.export_pdf import md_to_pdf
from quread.heatmap import make_activity_heatmap, HeatmapConfig
# --- Qubit cap (configurable) ---
DEFAULT_MAX_QUBITS = 16  # safe default for CPU Spaces; change if you want


def _read_max_qubits(default: int = DEFAULT_MAX_QUBITS) -> int:
    """Return the qubit cap from the ``QUREAD_MAX_QUBITS`` env var.

    An unset or empty variable yields *default*. A non-integer value also
    yields *default* (with a warning) instead of crashing the app at import
    time. The result is clamped to at least 1 because the qubit slider is
    built as ``gr.Slider(1, MAX_QUBITS, ...)``.
    """
    raw = os.getenv("QUREAD_MAX_QUBITS", "").strip()
    if not raw:
        return max(1, int(default))
    try:
        value = int(raw)
    except ValueError:
        import warnings

        warnings.warn(
            f"QUREAD_MAX_QUBITS={raw!r} is not an integer; using {default}.",
            stacklevel=2,
        )
        value = int(default)
    return max(1, value)


MAX_QUBITS = _read_max_qubits()
# ---------- Helpers ----------
def _new_sim(n_qubits: int):
    """Build a fresh simulator together with the default per-session UI state.

    Returns a ``(simulator, last_counts, selected_gate)`` triple: a new
    ``QuantumStateVector`` for ``int(n_qubits)`` qubits, ``None`` for the
    (not yet sampled) counts, and ``"H"`` as the initially selected gate.
    """
    simulator = QuantumStateVector(int(n_qubits))
    return simulator, None, "H"
def _probs_top(qc, n_qubits: int, k: int = 6):
    """Return the *k* most likely basis states of ``qc.state``.

    ``qc.state`` is read as a flat amplitude vector whose index ``i`` is the
    basis state ``i``. Each result entry is ``(bitstring, probability)``. The
    bitstring is ``i`` zero-padded to *n_qubits* binary digits, and the
    probability is ``|amplitude|**2`` as a Python float. Entries are ordered by
    descending probability, and ties keep ascending index order.

    Only the selected entries are formatted. The vector has 2**n_qubits
    entries, so formatting and Python-sorting all of them just to keep *k* is
    wasted work on every view refresh.
    """
    probs = np.abs(np.asarray(qc.state)).ravel() ** 2
    # A stable argsort of the negated probabilities yields descending order
    # while keeping index order among ties. This is the same ordering as
    # ``sorted(enumerate(probs), key=prob, reverse=True)``.
    order = np.argsort(-probs, kind="stable")[:k]
    return [(format(int(i), f"0{n_qubits}b"), float(probs[i])) for i in order]
def _counts_str(counts):
    """Render measurement counts as one ``"<outcome>: <count>"`` line per entry.

    Lines follow the mapping's iteration order. ``None`` or an empty mapping
    renders as the empty string.
    """
    rendered = (f"{outcome}: {tally}" for outcome, tally in (counts or {}).items())
    return "\n".join(rendered)
def _write_tmp(filename: str, content: str) -> str:
    """Write *content* as UTF-8 to a new temp file named *filename*.

    Each call gets its own private directory from ``tempfile.mkdtemp``. The
    download keeps the friendly *filename*, but concurrent sessions can no
    longer overwrite (or be served) each other's files, which would happen
    if every export went to one fixed path in the shared temp directory.
    The directories are not cleaned up here.

    Returns:
        The absolute path of the written file, as a string.
    """
    out_dir = pathlib.Path(tempfile.mkdtemp(prefix="quread_"))
    path = out_dir / filename
    path.write_text(content, encoding="utf-8")
    return str(path)
def _circuit_hash(history):
    """Return a SHA-256 hex digest fingerprinting the circuit *history*.

    The digest covers ``str(history)`` encoded as UTF-8, so two histories
    hash equal exactly when their string forms are equal.
    """
    fingerprint = str(history).encode("utf-8")
    digest = hashlib.sha256(fingerprint)
    return digest.hexdigest()
def dl_qasm(qc, n_qubits):
    """Export the circuit as OpenQASM 2.0 to ``circuit.qasm``; return the file path."""
    qasm_text = to_openqasm2(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit.qasm", qasm_text)
def dl_qiskit(qc, n_qubits):
    """Export the circuit as Qiskit code to ``circuit_qiskit.py``; return the file path."""
    qiskit_source = to_qiskit(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit_qiskit.py", qiskit_source)
def dl_cirq(qc, n_qubits):
    """Export the circuit as Cirq code to ``circuit_cirq.py``; return the file path."""
    cirq_source = to_cirq(qc.history, n_qubits=int(n_qubits))
    return _write_tmp("circuit_cirq.py", cirq_source)
def dl_csv(qc):
    """Export the circuit history as CSV to ``circuit.csv``; return the file path."""
    csv_text = to_csv(qc.history)
    return _write_tmp("circuit.csv", csv_text)
def dl_skill(qc, n_qubits):
    """Export the circuit as a Skill script to ``circuit.il``; return the file path.

    The script is produced from the circuit's CSV form via ``csv_to_skill``.
    """
    return _write_tmp(
        "circuit.il",
        csv_to_skill(to_csv(qc.history), n_qubits=int(n_qubits)),
    )
def update_views(qc, last_counts, n_qubits):
    """Recompute every read-only view of the simulator.

    Returns, in this order:
    - the circuit drawing from ``draw_circuit_svg``;
    - the ket string (at most 16 terms);
    - the counts text;
    - the six most likely basis states.

    The UI wires these to the circuit, statevector, measurement and
    top-probabilities widgets respectively.
    """
    width = int(n_qubits)
    diagram = draw_circuit_svg(qc.history, n_qubits=width)
    ket_text = qc.ket_notation(max_terms=16)
    leaders = _probs_top(qc, width, k=6)
    return diagram, ket_text, _counts_str(last_counts), leaders
def init_or_reset(n_qubits):
    """Replace the session simulator with a fresh one for *n_qubits* qubits.

    Returns ``(simulator, last_counts, selected_gate, status_message)``.
    """
    sim, counts, gate = _new_sim(n_qubits)
    return sim, counts, gate, "✅ Reset done."
def set_gate(gate):
    """Select *gate* in the palette; return it together with a status message."""
    message = f"✅ Selected gate: {gate}"
    return gate, message
def apply_selected_gate(qc, last_counts, selected_gate, target):
    """Apply the palette's selected single-qubit gate to qubit *target*.

    Previously sampled counts are dropped (returned as ``None``) because the
    state has changed.

    Returns:
        ``(qc, None, status_message)``.
    """
    qubit = int(target)
    qc.apply_single(selected_gate, target=qubit)
    return qc, None, f"✅ Applied {selected_gate} on q{target}."
def apply_cnot(qc, last_counts, control, target):
    """Apply a CNOT with qubit *control* driving qubit *target*.

    If both indices name the same qubit, nothing is applied. The previous
    counts are then passed through unchanged with an error message.
    Otherwise the gate is applied and the counts are cleared (``None``).

    Returns:
        ``(qc, last_counts_or_None, status_message)``.
    """
    ctrl, tgt = int(control), int(target)
    if ctrl != tgt:
        qc.apply_cnot(control=ctrl, target=tgt)
        return qc, None, f"✅ Applied CNOT (q{control} -> q{target})."
    return qc, last_counts, "❌ Control and target must be different."
def sample_shots(qc, shots):
    """Sample ``int(shots)`` outcomes via ``qc.sample``.

    Returns:
        ``(counts, status_message)``.
    """
    n_shots = int(shots)
    return qc.sample(shots=n_shots), "✅ Sampled shots."
def measure_collapse(qc, shots):
    """Measure with collapse, then sample ``int(shots)`` outcomes.

    ``qc.measure_collapse()`` runs first, so the sampling sees whatever state
    the collapse leaves behind.

    Returns:
        ``(counts, status_message)``, where the message names the collapsed
        outcome.
    """
    outcome = qc.measure_collapse()
    counts = qc.sample(shots=int(shots))
    return counts, f"✅ Collapsed to |{outcome}⟩ and sampled shots."
def explain_llm(qc, n_qubits, shots, last_hash):
    """Generate a GPT-4o explanation of the current circuit.

    Args:
        qc: The session simulator. Its ``history``, ket string (6 terms) and
            top probabilities are passed to ``explain_with_gpt4o``.
        n_qubits: Register width used to label basis states.
        shots: Shot count forwarded to the explainer.
        last_hash: ``_circuit_hash`` of the circuit explained last time
            (``""`` before the first explanation).

    Returns:
        ``(llm_markdown, circuit_hash, explanation_state)``. These feed the
        explanation Markdown, the last-explained-hash State and the
        explanation State (used by the downloads).

    Cases:
        * The circuit hash equals *last_hash*: no request is made. The shown
          explanation and the stored one are both left untouched
          (``gr.update()`` with no arguments means "no change"), and a toast
          reports the reuse. Returning the notice text instead would replace
          the very explanation it claims to reuse. Returning ``""`` for the
          stored Markdown would make the downloads write empty files.
        * ``allow_request`` refuses the estimated tokens: a notice is shown,
          the stored explanation is cleared, and the hash is not advanced, so
          a later click tries again.
        * Otherwise the explanation is returned for both display and
          download, together with the new hash.
    """
    curr_hash = _circuit_hash(qc.history)
    if curr_hash == last_hash:
        gr.Info("Circuit unchanged. Explanation reused.")
        return gr.update(), last_hash, gr.update()

    # Rough per-request token estimate checked against the daily limit.
    EST_TOKENS = 900
    if not allow_request(EST_TOKENS):
        return "🚫 Explanation disabled (daily token limit reached).", last_hash, ""

    state_ket = qc.ket_notation(max_terms=6)
    probs_top = _probs_top(qc, int(n_qubits), k=6)
    explanation = explain_with_gpt4o(
        n_qubits=int(n_qubits),
        history=qc.history,
        state_ket=state_ket,
        probs_top=probs_top,
        shots=int(shots),
    )
    return explanation, curr_hash, explanation
def _refresh_choices(n):
    """Rebuild the three qubit pickers for an *n*-qubit register.

    Returns updated dropdowns for the gate target, the CNOT control and the
    CNOT target, all offering qubits ``0 .. n-1``. The first two default to
    qubit 0. The CNOT target defaults to ``min(1, n - 1)``, i.e. qubit 1, or
    qubit 0 when only one qubit exists.
    """
    count = int(n)
    qubits = list(range(count))
    target_picker = gr.Dropdown(choices=qubits, value=0)
    control_picker = gr.Dropdown(choices=qubits, value=0)
    cnot_target_picker = gr.Dropdown(choices=qubits, value=min(1, count - 1))
    return target_picker, control_picker, cnot_target_picker
# ---------- Styling ----------
# Custom stylesheet passed to gr.Blocks(css=...). Selectors used by the layout:
#   #title / #subtitle - header <div>s emitted through gr.Markdown
#   .sidebar           - left settings column (sticky panel sized to the viewport)
#   .card              - white bordered panels applied via elem_classes on gr.Group
#   .section-title     - headings inside cards
#   .small-note        - muted helper text
CSS = """
#title h1 { font-size: 42px !important; margin-bottom: 6px; }
#subtitle { color: #6b7280; margin-top: 0px; }
.sidebar {
background: #f6f7fb;
border-right: 1px solid #e5e7eb;
border-radius: 14px;
padding: 18px 14px;
height: calc(100vh - 34px);
position: sticky;
top: 16px;
}
.card {
background: white;
border: 1px solid #e5e7eb;
border-radius: 14px;
padding: 14px;
}
.section-title { font-size: 22px; font-weight: 700; margin: 6px 0 10px; }
.small-note { color: #6b7280; font-size: 12px; }
"""
# Gradio's built-in "Soft" theme with large corner radius and medium text size.
theme = gr.themes.Soft(radius_size="lg", text_size="md")
with gr.Blocks(theme=theme, css=CSS, title="Quread.ai — State Vector Studio") as demo:
    # ---- Per-session state ----
    qc_state = gr.State()  # simulator created by _new_sim
    last_counts_state = gr.State()  # counts from the latest sample (None after a gate)
    selected_gate_state = gr.State()  # gate name chosen in the palette
    explanation_md = gr.State("")  # last explanation Markdown, used by the downloads
    last_explained_hash = gr.State("")  # _circuit_hash of the circuit last explained

    with gr.Row():
        # Sidebar
        with gr.Column(scale=3, elem_classes=["sidebar"]):
            gr.Markdown("### Simulator Settings")
            n_qubits = gr.Slider(1, MAX_QUBITS, value=2, step=1, label="Number of qubits")
            gr.Markdown(f"<div class='small-note'>Max qubits: <b>{MAX_QUBITS}</b> (set env var <code>QUREAD_MAX_QUBITS</code> to change)</div>")
            shots = gr.Slider(128, 8192, value=1024, step=128, label="Shots")
            gr.Markdown("---")
            gr.Markdown("### Explanation (GPT-4o)")
            gr.Markdown("<div class='small-note'>Uses GPT-4o with cost guards.</div>")
            reset_btn = gr.Button("Reset Simulator", variant="secondary")

        # Main
        with gr.Column(scale=9):
            gr.Markdown("<div id='title'><h1>Quread.ai — State Vector Studio</h1></div>")
            gr.Markdown("<div id='subtitle'>Build circuits, visualize, export, and get teaching-ready explanations.</div>")
            status = gr.Markdown("")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("<div class='section-title'>Gate Palette</div>")
                with gr.Row():
                    gate_H = gr.Button("H")
                    gate_T = gr.Button("T")
                    gate_Tdg = gr.Button("T†")
                    gate_X = gr.Button("X")
                with gr.Row():
                    gate_Y = gr.Button("Y")
                    gate_Z = gr.Button("Z")
                    gate_sx = gr.Button("√X")
                    gate_sz = gr.Button("√Z")
                with gr.Row():
                    gate_rx_pi = gr.Button("Rx(π)")
                    gate_rx_pi2 = gr.Button("Rx(π/2)")
                    gate_ry_pi = gr.Button("Ry(π)")
                    gate_ry_pi2 = gr.Button("Ry(π/2)")
                with gr.Row():
                    gate_rz_pi = gr.Button("Rz(π)")
                    gate_rz_pi2 = gr.Button("Rz(π/2)")
                    gate_I = gr.Button("I")
                    gate_Idg = gr.Button("I†")  # will map to I
                with gr.Row():
                    gate_S = gr.Button("S")
                    gate_Sdg = gr.Button("S†")
                    target = gr.Dropdown(choices=[0, 1], value=0, label="Target qubit")
                with gr.Row():
                    apply_gate_btn = gr.Button("Apply Selected Gate", variant="primary")
                    sample_btn = gr.Button("Sample shots")
                    measure_btn = gr.Button("Measure + Collapse")
                gr.Markdown("**CNOT**")
                with gr.Row():
                    control = gr.Dropdown(choices=[0, 1], value=0, label="Control")
                    cnot_target = gr.Dropdown(choices=[0, 1], value=1, label="Target")
                    apply_cnot_btn = gr.Button("Apply CNOT")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("<div class='section-title'>Circuit Diagram</div>")
                circuit_html = gr.HTML()

            with gr.Row():
                with gr.Column(scale=7):
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("<div class='section-title'>Statevector</div>")
                        ket_out = gr.Code(label="", language="python")
                        gr.Markdown("<div class='section-title'>Top probabilities</div>")
                        probs_out = gr.Dataframe(headers=["bitstring", "prob"], interactive=False)
                with gr.Column(scale=5):
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("<div class='section-title'>Measurement distribution</div>")
                        counts_out = gr.Textbox(lines=10)
                    with gr.Group(elem_classes=["card"]):
                        gr.Markdown("<div class='section-title'>Export</div>")
                        qasm_dl = gr.DownloadButton("Download OpenQASM 2.0")
                        qiskit_dl = gr.DownloadButton("Download Qiskit code")
                        cirq_dl = gr.DownloadButton("Download Cirq code")
                        csv_dl = gr.DownloadButton("Download CSV")
                        skill_dl = gr.DownloadButton("Download Skill script")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("<div class='section-title'>Heatmap</div>")
                chip_rows = gr.Slider(2, 64, value=8, step=1, label="Chip rows")
                chip_cols = gr.Slider(2, 64, value=8, step=1, label="Chip cols")
                heat_btn = gr.Button("Generate heatmap from CSV", variant="secondary")
                heat_plot = gr.Plot()

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("<div class='section-title'>Explain (GPT-4o)</div>")
                explain_btn = gr.Button("Explain", variant="primary")
                llm_out = gr.Markdown()
                with gr.Row():
                    dl_md = gr.DownloadButton("Download Explanation (Markdown)")
                    dl_pdf = gr.DownloadButton("Download Explanation (PDF)")
                gr.Markdown("<div class='small-note'>Tip: sample shots first.</div>")

    # ---- Event wiring ----
    # init
    def _init_all(n):
        """Create the session's simulator state when the page loads."""
        qc, last_counts, selected_gate = _new_sim(n)
        return qc, last_counts, selected_gate

    demo.load(
        fn=_init_all,
        inputs=[n_qubits],
        outputs=[qc_state, last_counts_state, selected_gate_state],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    # Changing the qubit count must rebuild the simulator. It is sized once at
    # construction (QuantumStateVector(n)) and nothing here resizes it, so
    # keeping the old one would leave the pickers offering qubits it does not
    # have, and update_views would label bitstrings at the wrong width.
    n_qubits.change(
        fn=init_or_reset,
        inputs=[n_qubits],
        outputs=[qc_state, last_counts_state, selected_gate_state, status],
    ).then(
        fn=_refresh_choices,
        inputs=[n_qubits],
        outputs=[target, control, cnot_target],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    reset_btn.click(
        fn=init_or_reset,
        inputs=[n_qubits],
        outputs=[qc_state, last_counts_state, selected_gate_state, status],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    # Palette buttons only record the chosen gate name; applying happens on
    # "Apply Selected Gate". Labels such as "Rx(π)" map to engine names "RX(π)".
    gate_H.click(fn=lambda: set_gate("H"), outputs=[selected_gate_state, status])
    gate_T.click(fn=lambda: set_gate("T"), outputs=[selected_gate_state, status])
    gate_Tdg.click(fn=lambda: set_gate("T†"), outputs=[selected_gate_state, status])
    gate_X.click(fn=lambda: set_gate("X"), outputs=[selected_gate_state, status])
    gate_Y.click(fn=lambda: set_gate("Y"), outputs=[selected_gate_state, status])
    gate_Z.click(fn=lambda: set_gate("Z"), outputs=[selected_gate_state, status])
    gate_sx.click(fn=lambda: set_gate("√X"), outputs=[selected_gate_state, status])
    gate_sz.click(fn=lambda: set_gate("√Z"), outputs=[selected_gate_state, status])
    gate_rx_pi.click(fn=lambda: set_gate("RX(π)"), outputs=[selected_gate_state, status])
    gate_rx_pi2.click(fn=lambda: set_gate("RX(π/2)"), outputs=[selected_gate_state, status])
    gate_ry_pi.click(fn=lambda: set_gate("RY(π)"), outputs=[selected_gate_state, status])
    gate_ry_pi2.click(fn=lambda: set_gate("RY(π/2)"), outputs=[selected_gate_state, status])
    gate_rz_pi.click(fn=lambda: set_gate("RZ(π)"), outputs=[selected_gate_state, status])
    gate_rz_pi2.click(fn=lambda: set_gate("RZ(π/2)"), outputs=[selected_gate_state, status])
    gate_I.click(fn=lambda: set_gate("I"), outputs=[selected_gate_state, status])
    gate_Idg.click(fn=lambda: set_gate("I"), outputs=[selected_gate_state, status])  # treat I† as I
    gate_S.click(fn=lambda: set_gate("S"), outputs=[selected_gate_state, status])
    gate_Sdg.click(fn=lambda: set_gate("S†"), outputs=[selected_gate_state, status])

    apply_gate_btn.click(
        fn=apply_selected_gate,
        inputs=[qc_state, last_counts_state, selected_gate_state, target],
        outputs=[qc_state, last_counts_state, status],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    apply_cnot_btn.click(
        fn=apply_cnot,
        inputs=[qc_state, last_counts_state, control, cnot_target],
        outputs=[qc_state, last_counts_state, status],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    sample_btn.click(
        fn=sample_shots,
        inputs=[qc_state, shots],
        outputs=[last_counts_state, status],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    measure_btn.click(
        fn=measure_collapse,
        inputs=[qc_state, shots],
        outputs=[last_counts_state, status],
    ).then(
        fn=update_views,
        inputs=[qc_state, last_counts_state, n_qubits],
        outputs=[circuit_html, ket_out, counts_out, probs_out],
    )

    # Each download button regenerates its file on click and serves it.
    qasm_dl.click(fn=dl_qasm, inputs=[qc_state, n_qubits], outputs=[qasm_dl])
    qiskit_dl.click(fn=dl_qiskit, inputs=[qc_state, n_qubits], outputs=[qiskit_dl])
    cirq_dl.click(fn=dl_cirq, inputs=[qc_state, n_qubits], outputs=[cirq_dl])
    csv_dl.click(fn=dl_csv, inputs=[qc_state], outputs=[csv_dl])
    skill_dl.click(fn=dl_skill, inputs=[qc_state, n_qubits], outputs=[skill_dl])

    explain_btn.click(
        fn=explain_llm,
        inputs=[qc_state, n_qubits, shots, last_explained_hash],
        outputs=[llm_out, last_explained_hash, explanation_md],
    )

    def _heat_from_current(qc, n_qubits, rows, cols):
        """Build the activity heatmap figure for the current circuit.

        The circuit is serialized with ``to_csv`` (the same text the CSV
        export uses). The grid size comes from the "Chip rows" / "Chip cols"
        sliders.
        """
        csv_text = to_csv(qc.history)
        cfg = HeatmapConfig(rows=int(rows), cols=int(cols))
        return make_activity_heatmap(csv_text, int(n_qubits), cfg=cfg)

    heat_btn.click(
        fn=_heat_from_current,
        inputs=[qc_state, n_qubits, chip_rows, chip_cols],
        outputs=[heat_plot],
    )

    def dl_explain_md(md_text):
        """Write the stored explanation Markdown to ``explanation.md``; return its path."""
        return _write_tmp("explanation.md", md_text)

    def dl_explain_pdf(md_text):
        """Render the stored explanation to ``explanation.pdf``; return its path.

        A private temp directory is created per call so concurrent sessions
        cannot overwrite each other's PDF at a shared fixed path.
        """
        path = pathlib.Path(tempfile.mkdtemp(prefix="quread_")) / "explanation.pdf"
        md_to_pdf(md_text, str(path))
        return str(path)

    dl_md.click(fn=dl_explain_md, inputs=[explanation_md], outputs=[dl_md])
    dl_pdf.click(fn=dl_explain_pdf, inputs=[explanation_md], outputs=[dl_pdf])
if __name__ == "__main__":
    # Bind to every interface (not just localhost) on port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")