Spaces:
Sleeping
Sleeping
File size: 5,196 Bytes
0b89610 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | from __future__ import annotations
import requests
import streamlit as st
# Base URL of the optimizer backend. The "API_URL" Streamlit secret wins when
# it is configured; otherwise the local development server is used.
_DEFAULT_API_URL = "http://localhost:7860"
try:
    API_URL = st.secrets.get("API_URL", _DEFAULT_API_URL)
except FileNotFoundError:
    # `st.secrets` always exists as an attribute, so a `hasattr` check cannot
    # detect a missing secrets file; the lookup itself raises in that case.
    # NOTE(review): recent Streamlit releases raise
    # StreamlitSecretNotFoundError here, which is expected to subclass
    # FileNotFoundError -- confirm against the pinned Streamlit version.
    API_URL = _DEFAULT_API_URL

# Page chrome: browser tab title/icon, wide layout, and the visible header.
st.set_page_config(page_title="rag-context-optimizer", page_icon="R", layout="wide")
st.title("RAG Context Optimizer")
st.caption("Use any prompt, keep the token budget tight, and let the optimizer pick the best evidence per token.")
def api_get(path: str):
    """Send a GET request to the backend and return the decoded JSON body.

    Args:
        path: Route appended verbatim to ``API_URL`` (for example ``"/tasks"``).

    Returns:
        The value produced by decoding the response body as JSON.

    Raises:
        requests.HTTPError: Raised by ``raise_for_status`` when the backend
            answers with an error status code.
    """
    url = f"{API_URL}{path}"
    reply = requests.get(url, timeout=20)
    # Surface HTTP error statuses as exceptions before trying to decode.
    reply.raise_for_status()
    return reply.json()
def api_post(path: str, payload: dict | None = None):
    """Send a POST request with a JSON body and return the decoded JSON reply.

    Args:
        path: Route appended verbatim to ``API_URL``.
        payload: JSON-serialisable request body. A falsy value (``None`` or an
            empty dict) is sent as ``{}``.

    Returns:
        The value produced by decoding the response body as JSON.

    Raises:
        requests.HTTPError: Raised by ``raise_for_status`` when the backend
            answers with an error status code.
    """
    body = payload if payload else {}
    url = f"{API_URL}{path}"
    reply = requests.post(url, json=body, timeout=20)
    # Surface HTTP error statuses as exceptions before trying to decode.
    reply.raise_for_status()
    return reply.json()
def start_episode(task_name: str, query: str, token_budget: int, max_steps: int):
    """Reset the backend episode and cache the reply in the session state.

    Posts the episode settings to ``/reset`` and stores the decoded response
    under ``st.session_state["payload"]``.

    Args:
        task_name: Name of the task preset, sent as ``task_name``.
        query: Prompt text, sent as ``custom_query``.
        token_budget: Sent as ``token_budget``.
        max_steps: Sent as ``max_steps``.
    """
    reset_request = {
        "task_name": task_name,
        "custom_query": query,
        "token_budget": token_budget,
        "max_steps": max_steps,
    }
    st.session_state["payload"] = api_post("/reset", reset_request)
def do_step(payload: dict):
    """Post one action to ``/step`` and cache the reply in the session state.

    Args:
        payload: Action body forwarded unchanged to the backend.
    """
    step_result = api_post("/step", payload)
    st.session_state["payload"] = step_result
# --- Sidebar: task preset, prompt, and budget controls ----------------------
try:
    tasks = api_get("/tasks")
except requests.RequestException as exc:
    # Show a readable message instead of a raw traceback when the backend is
    # unreachable or answers with an error status.
    st.error(f"Could not load task presets from the API: {exc}")
    st.stop()

task_map = {task["name"]: task for task in tasks}
if not task_map:
    # With no options the selectbox returns None, which would fail the
    # `task_map[selected_task]` lookup below.
    st.error("The API returned no task presets.")
    st.stop()

selected_task = st.sidebar.selectbox("Task preset", list(task_map))
task_meta = task_map[selected_task]

# The prompt is mirrored into the session state so it survives reruns.
default_query = st.session_state.get("custom_query", "")
custom_query = st.sidebar.text_area(
    "Custom prompt",
    value=default_query,
    height=180,
    placeholder="Enter any prompt you want to optimize for minimal token usage.",
)

# `number_input` rejects a default below `min_value`, so preset values are
# clamped to the widget minimum before being used as the initial value.
token_budget = st.sidebar.number_input(
    "Token budget",
    min_value=50,
    value=max(50, int(task_meta["token_budget"])),
    step=10,
)
max_steps = st.sidebar.number_input(
    "Max steps",
    min_value=1,
    value=max(1, int(task_meta["max_steps"])),
    step=1,
)
st.session_state["custom_query"] = custom_query

sidebar_cols = st.sidebar.columns(2)
if sidebar_cols[0].button("Start / Reset", use_container_width=True):
    if not custom_query.strip():
        st.sidebar.error("Enter a custom prompt first.")
    else:
        start_episode(selected_task, custom_query.strip(), int(token_budget), int(max_steps))
        st.rerun()
if sidebar_cols[1].button("Refresh", use_container_width=True):
    st.rerun()
# --- Main panel: episode summary --------------------------------------------
# Nothing to show until an episode has been started from the sidebar.
if "payload" not in st.session_state:
    st.info("Add your prompt in the sidebar and press Start / Reset.")
    st.stop()

payload = st.session_state["payload"]
observation = payload["observation"]

col1, col2, col3, col4 = st.columns(4)
col1.metric("Task", observation["task_name"])
col2.metric("Budget", observation["token_budget"])
col3.metric("Used", observation["total_tokens_used"])
col4.metric("Step", observation["step_number"])

st.subheader("Active Query")
st.info(observation["query"])

feedback = observation.get("last_action_feedback")
if feedback:
    st.warning(feedback)

# `.get(key, default)` only covers a missing key; a key that is present with
# a null value still comes back as None. Normalise both fields so a null
# "info" or "reward" cannot raise while rendering the final score.
info = payload.get("info") or {}
grader_breakdown = info.get("grader_breakdown")
if grader_breakdown:
    reward = payload.get("reward") or 0
    st.success(f"Final score: {reward:.4f}")
    st.json(grader_breakdown)
# --- Episode actions: single auto step, auto run, manual answer -------------
step_col, run_col, answer_col = st.columns(3)

# One backend-suggested action, applied immediately.
if step_col.button("Auto Optimize Step", use_container_width=True):
    suggestion = api_post("/optimize-step")
    do_step(suggestion)
    st.rerun()

# Repeat suggested actions, at most 20 times, stopping early once an answer
# has been submitted or the latest step payload reports "done".
if run_col.button("Auto Run", use_container_width=True):
    attempts_left = 20
    while attempts_left > 0:
        attempts_left -= 1
        suggestion = api_post("/optimize-step")
        do_step(suggestion)
        if suggestion["action_type"] == "submit_answer":
            break
        if st.session_state["payload"]["done"]:
            break
    st.rerun()

manual_answer = answer_col.text_input("Manual answer", value="")
if st.button("Submit Manual Answer", type="primary", use_container_width=True):
    answer_text = manual_answer.strip()
    if not answer_text:
        # Blank input falls back to a canned placeholder answer.
        answer_text = "Concise answer synthesized from the selected evidence."
    do_step({"action_type": "submit_answer", "answer": answer_text})
    st.rerun()
# --- Chunk grid: one bordered card per chunk, alternating between 2 columns -
st.subheader("Available Chunks")
chunk_columns = st.columns(2)

# Build the membership set once; it does not change while the cards render.
selected_ids = set(observation.get("selected_chunks") or ())

for index, chunk in enumerate(observation["available_chunks"]):
    chunk_id = chunk["chunk_id"]
    selected = chunk_id in selected_ids

    container = chunk_columns[index % 2].container(border=True)
    container.markdown(f"**{chunk_id}**")
    container.caption(f"{chunk['domain']} | {chunk['tokens']} tokens")
    container.write(", ".join(chunk["keywords"]))

    # Left button toggles selection; right button requests compression.
    c1, c2 = container.columns(2)
    if selected:
        if c1.button("Deselect", key=f"deselect-{chunk_id}", use_container_width=True):
            do_step({"action_type": "deselect_chunk", "chunk_id": chunk_id})
            st.rerun()
    else:
        if c1.button("Select", key=f"select-{chunk_id}", use_container_width=True):
            do_step({"action_type": "select_chunk", "chunk_id": chunk_id})
            st.rerun()
    if c2.button("Compress 50%", key=f"compress-{chunk_id}", use_container_width=True):
        do_step(
            {
                "action_type": "compress_chunk",
                "chunk_id": chunk_id,
                "compression_ratio": 0.5,
            }
        )
        st.rerun()
# --- Raw JSON views: last cached payload and the backend's /state reply ------
st.subheader("Observation")
st.json(payload)

st.subheader("State")
backend_state = api_get("/state")
st.json(backend_state)
|