Spaces:
Running
Running
Upload 4 files
Browse files- app.py +78 -32
- pi_wrapper.py +34 -1
app.py
CHANGED
|
@@ -1230,6 +1230,22 @@ def _resolve_example_corpus(corpus_dir: Path, question: str) -> Path:
|
|
| 1230 |
return candidate if candidate.exists() else corpus_dir
|
| 1231 |
|
| 1232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1233 |
def _ensure_corpus_ready(
|
| 1234 |
use_default: bool,
|
| 1235 |
uploaded_zip: Optional[str],
|
|
@@ -1836,32 +1852,58 @@ def run_search(
|
|
| 1836 |
mode: str,
|
| 1837 |
corpus_source: str,
|
| 1838 |
uploaded_zip: Optional[str],
|
| 1839 |
-
|
| 1840 |
-
|
|
|
|
| 1841 |
|
| 1842 |
if not api_key or not api_key.strip():
|
| 1843 |
-
yield _TERM_IDLE, "⚠ OpenAI API Key is required."
|
| 1844 |
return
|
| 1845 |
if not question or not question.strip():
|
| 1846 |
-
yield _TERM_IDLE, "⚠ Please enter a question."
|
| 1847 |
return
|
| 1848 |
|
|
|
|
| 1849 |
use_default_corpus = corpus_source == DEFAULT_CORPUS_LABEL
|
| 1850 |
selected_domain = _selected_example_corpus(question) if use_default_corpus else None
|
| 1851 |
-
|
| 1852 |
-
|
| 1853 |
-
|
| 1854 |
-
|
| 1855 |
-
|
| 1856 |
-
|
| 1857 |
-
|
| 1858 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1859 |
|
|
|
|
| 1860 |
cwd = _resolve_example_corpus(corpus_root, question) if use_default_corpus else corpus_root
|
| 1861 |
|
| 1862 |
prompt = build_ir_prompt(question, cwd) if mode == "IR" else build_benchmark_prompt(question, cwd)
|
| 1863 |
state = _initial_terminal_state(question, cwd, mode, model, max_turns, selected_domain)
|
| 1864 |
-
yield _render_terminal_state(state), "⚙ Running…"
|
| 1865 |
|
| 1866 |
try:
|
| 1867 |
for event in run_pi_stream(
|
|
@@ -1871,26 +1913,27 @@ def run_search(
|
|
| 1871 |
provider="openai",
|
| 1872 |
model=model,
|
| 1873 |
max_turns=max_turns,
|
|
|
|
|
|
|
| 1874 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1875 |
if _apply_event(state, event):
|
| 1876 |
status = "✓ Completed" if event.get("type") == "__final__" else "⚙ Running…"
|
| 1877 |
-
yield _render_terminal_state(state), status
|
| 1878 |
|
| 1879 |
except Exception as exc:
|
| 1880 |
_apply_event(state, {"type": "error", "error": str(exc)})
|
| 1881 |
-
yield _render_terminal_state(state), f"⚠ {exc}"
|
| 1882 |
return
|
| 1883 |
-
finally:
|
| 1884 |
-
try:
|
| 1885 |
-
p = cwd.parent
|
| 1886 |
-
if p.name.startswith("pi_corpus_"):
|
| 1887 |
-
shutil.rmtree(p, ignore_errors=True)
|
| 1888 |
-
except Exception:
|
| 1889 |
-
pass
|
| 1890 |
|
| 1891 |
if not state.get("finalized"):
|
| 1892 |
_apply_event(state, {"type": "__final__", "text": ""})
|
| 1893 |
-
yield _render_terminal_state(state), "✓ Completed"
|
| 1894 |
|
| 1895 |
|
| 1896 |
# ──────────────────────────────────────────────────────────────────────────
|
|
@@ -1899,6 +1942,7 @@ def run_search(
|
|
| 1899 |
def build_ui() -> gr.Blocks:
|
| 1900 |
hero_logo_src = f"/gradio_api/file={HERO_LOGO_PATH}"
|
| 1901 |
with gr.Blocks(title="DCI-Agent Search") as demo:
|
|
|
|
| 1902 |
with gr.Row(equal_height=False):
|
| 1903 |
|
| 1904 |
# ── Left sidebar ───────────────────────────────────────────────
|
|
@@ -2053,28 +2097,30 @@ def build_ui() -> gr.Blocks:
|
|
| 2053 |
outputs=[uploaded_zip, corpus_hint],
|
| 2054 |
)
|
| 2055 |
|
| 2056 |
-
def clear_all():
|
|
|
|
| 2057 |
return (
|
| 2058 |
-
_TERM_IDLE, "", "Ready ✔",
|
| 2059 |
)
|
| 2060 |
|
| 2061 |
clear_btn.click(
|
| 2062 |
fn=clear_all,
|
| 2063 |
-
|
|
|
|
| 2064 |
)
|
| 2065 |
|
| 2066 |
-
def _run_with_qa_mode(api_key, question, model, max_turns, corpus_source, uploaded_zip):
|
| 2067 |
-
yield from run_search(api_key, question, model, max_turns, "QA", corpus_source, uploaded_zip)
|
| 2068 |
|
| 2069 |
run_event = run_btn.click(
|
| 2070 |
fn=_run_with_qa_mode,
|
| 2071 |
-
inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip],
|
| 2072 |
-
outputs=[terminal, status_text],
|
| 2073 |
)
|
| 2074 |
terminal_question.submit(
|
| 2075 |
fn=_run_with_qa_mode,
|
| 2076 |
-
inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip],
|
| 2077 |
-
outputs=[terminal, status_text],
|
| 2078 |
)
|
| 2079 |
stop_btn.click(fn=None, cancels=[run_event])
|
| 2080 |
|
|
|
|
| 1230 |
return candidate if candidate.exists() else corpus_dir
|
| 1231 |
|
| 1232 |
|
| 1233 |
+
def _cleanup_runtime_state(runtime_state: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 1234 |
+
state = dict(runtime_state or {})
|
| 1235 |
+
corpus_root = state.get("corpus_root")
|
| 1236 |
+
session_dir = state.get("session_dir")
|
| 1237 |
+
for path_value in (corpus_root, session_dir):
|
| 1238 |
+
if not path_value:
|
| 1239 |
+
continue
|
| 1240 |
+
try:
|
| 1241 |
+
path = Path(path_value)
|
| 1242 |
+
if path.exists():
|
| 1243 |
+
shutil.rmtree(path, ignore_errors=True)
|
| 1244 |
+
except Exception:
|
| 1245 |
+
pass
|
| 1246 |
+
return {}
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
def _ensure_corpus_ready(
|
| 1250 |
use_default: bool,
|
| 1251 |
uploaded_zip: Optional[str],
|
|
|
|
| 1852 |
mode: str,
|
| 1853 |
corpus_source: str,
|
| 1854 |
uploaded_zip: Optional[str],
|
| 1855 |
+
runtime_state: Optional[Dict[str, Any]],
|
| 1856 |
+
) -> Generator[tuple[str, str, Dict[str, Any]], None, None]:
|
| 1857 |
+
"""Yields (terminal_html, status_str, runtime_state)."""
|
| 1858 |
|
| 1859 |
if not api_key or not api_key.strip():
|
| 1860 |
+
yield _TERM_IDLE, "⚠ OpenAI API Key is required.", dict(runtime_state or {})
|
| 1861 |
return
|
| 1862 |
if not question or not question.strip():
|
| 1863 |
+
yield _TERM_IDLE, "⚠ Please enter a question.", dict(runtime_state or {})
|
| 1864 |
return
|
| 1865 |
|
| 1866 |
+
runtime = dict(runtime_state or {})
|
| 1867 |
use_default_corpus = corpus_source == DEFAULT_CORPUS_LABEL
|
| 1868 |
selected_domain = _selected_example_corpus(question) if use_default_corpus else None
|
| 1869 |
+
corpus_key = json.dumps(
|
| 1870 |
+
{
|
| 1871 |
+
"source": corpus_source,
|
| 1872 |
+
"uploaded_zip": uploaded_zip or "",
|
| 1873 |
+
"selected_domain": selected_domain or "",
|
| 1874 |
+
},
|
| 1875 |
+
ensure_ascii=False,
|
| 1876 |
+
sort_keys=True,
|
| 1877 |
+
)
|
| 1878 |
+
corpus_root_value = runtime.get("corpus_root")
|
| 1879 |
+
corpus_root = Path(corpus_root_value) if corpus_root_value else None
|
| 1880 |
+
if runtime.get("corpus_key") != corpus_key or corpus_root is None or not corpus_root.exists():
|
| 1881 |
+
runtime = _cleanup_runtime_state(runtime)
|
| 1882 |
+
try:
|
| 1883 |
+
corpus_root = _ensure_corpus_ready(use_default_corpus, uploaded_zip, selected_domain)
|
| 1884 |
+
except ValueError as exc:
|
| 1885 |
+
yield _TERM_IDLE, f"⚠ {exc}", runtime
|
| 1886 |
+
return
|
| 1887 |
+
except Exception as exc:
|
| 1888 |
+
yield _TERM_IDLE, f"⚠ Corpus error: {exc}", runtime
|
| 1889 |
+
return
|
| 1890 |
+
session_dir = Path(tempfile.mkdtemp(prefix="pi_session_"))
|
| 1891 |
+
runtime.update(
|
| 1892 |
+
{
|
| 1893 |
+
"corpus_key": corpus_key,
|
| 1894 |
+
"corpus_root": str(corpus_root),
|
| 1895 |
+
"session_dir": str(session_dir),
|
| 1896 |
+
"session_file": "",
|
| 1897 |
+
"session_id": "",
|
| 1898 |
+
}
|
| 1899 |
+
)
|
| 1900 |
|
| 1901 |
+
assert corpus_root is not None
|
| 1902 |
cwd = _resolve_example_corpus(corpus_root, question) if use_default_corpus else corpus_root
|
| 1903 |
|
| 1904 |
prompt = build_ir_prompt(question, cwd) if mode == "IR" else build_benchmark_prompt(question, cwd)
|
| 1905 |
state = _initial_terminal_state(question, cwd, mode, model, max_turns, selected_domain)
|
| 1906 |
+
yield _render_terminal_state(state), "⚙ Running…", runtime
|
| 1907 |
|
| 1908 |
try:
|
| 1909 |
for event in run_pi_stream(
|
|
|
|
| 1913 |
provider="openai",
|
| 1914 |
model=model,
|
| 1915 |
max_turns=max_turns,
|
| 1916 |
+
session_dir=Path(runtime["session_dir"]) if runtime.get("session_dir") else None,
|
| 1917 |
+
session_path=Path(runtime["session_file"]) if runtime.get("session_file") else None,
|
| 1918 |
):
|
| 1919 |
+
if event.get("type") == "__session__":
|
| 1920 |
+
if event.get("sessionFile"):
|
| 1921 |
+
runtime["session_file"] = str(event["sessionFile"])
|
| 1922 |
+
if event.get("sessionId"):
|
| 1923 |
+
runtime["session_id"] = str(event["sessionId"])
|
| 1924 |
+
continue
|
| 1925 |
if _apply_event(state, event):
|
| 1926 |
status = "✓ Completed" if event.get("type") == "__final__" else "⚙ Running…"
|
| 1927 |
+
yield _render_terminal_state(state), status, runtime
|
| 1928 |
|
| 1929 |
except Exception as exc:
|
| 1930 |
_apply_event(state, {"type": "error", "error": str(exc)})
|
| 1931 |
+
yield _render_terminal_state(state), f"⚠ {exc}", runtime
|
| 1932 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1933 |
|
| 1934 |
if not state.get("finalized"):
|
| 1935 |
_apply_event(state, {"type": "__final__", "text": ""})
|
| 1936 |
+
yield _render_terminal_state(state), "✓ Completed", runtime
|
| 1937 |
|
| 1938 |
|
| 1939 |
# ──────────────────────────────────────────────────────────────────────────
|
|
|
|
| 1942 |
def build_ui() -> gr.Blocks:
|
| 1943 |
hero_logo_src = f"/gradio_api/file={HERO_LOGO_PATH}"
|
| 1944 |
with gr.Blocks(title="DCI-Agent Search") as demo:
|
| 1945 |
+
runtime_state = gr.State({})
|
| 1946 |
with gr.Row(equal_height=False):
|
| 1947 |
|
| 1948 |
# ── Left sidebar ───────────────────────────────────────────────
|
|
|
|
| 2097 |
outputs=[uploaded_zip, corpus_hint],
|
| 2098 |
)
|
| 2099 |
|
| 2100 |
+
def clear_all(runtime_state):
|
| 2101 |
+
cleared_runtime = _cleanup_runtime_state(runtime_state)
|
| 2102 |
return (
|
| 2103 |
+
_TERM_IDLE, "", "Ready ✔", cleared_runtime,
|
| 2104 |
)
|
| 2105 |
|
| 2106 |
clear_btn.click(
|
| 2107 |
fn=clear_all,
|
| 2108 |
+
inputs=[runtime_state],
|
| 2109 |
+
outputs=[terminal, terminal_question, status_text, runtime_state],
|
| 2110 |
)
|
| 2111 |
|
| 2112 |
+
def _run_with_qa_mode(api_key, question, model, max_turns, corpus_source, uploaded_zip, runtime_state):
|
| 2113 |
+
yield from run_search(api_key, question, model, max_turns, "QA", corpus_source, uploaded_zip, runtime_state)
|
| 2114 |
|
| 2115 |
run_event = run_btn.click(
|
| 2116 |
fn=_run_with_qa_mode,
|
| 2117 |
+
inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip, runtime_state],
|
| 2118 |
+
outputs=[terminal, status_text, runtime_state],
|
| 2119 |
)
|
| 2120 |
terminal_question.submit(
|
| 2121 |
fn=_run_with_qa_mode,
|
| 2122 |
+
inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip, runtime_state],
|
| 2123 |
+
outputs=[terminal, status_text, runtime_state],
|
| 2124 |
)
|
| 2125 |
stop_btn.click(fn=None, cancels=[run_event])
|
| 2126 |
|
pi_wrapper.py
CHANGED
|
@@ -93,6 +93,8 @@ class NativePiClient:
|
|
| 93 |
provider: str,
|
| 94 |
model: str,
|
| 95 |
tools: str,
|
|
|
|
|
|
|
| 96 |
api_key: str = "",
|
| 97 |
) -> None:
|
| 98 |
self.package_dir = package_dir
|
|
@@ -101,6 +103,8 @@ class NativePiClient:
|
|
| 101 |
self.provider = provider
|
| 102 |
self.model = model
|
| 103 |
self.tools = tools
|
|
|
|
|
|
|
| 104 |
self.api_key = api_key
|
| 105 |
self.proc: Optional[subprocess.Popen[bytes]] = None
|
| 106 |
self.stderr_chunks: List[str] = []
|
|
@@ -132,7 +136,10 @@ class NativePiClient:
|
|
| 132 |
cmd.extend(["--model", self.model])
|
| 133 |
if self.tools:
|
| 134 |
cmd.extend(["--tools", self.tools])
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
| 136 |
return cmd
|
| 137 |
|
| 138 |
def start(self) -> None:
|
|
@@ -185,6 +192,17 @@ class NativePiClient:
|
|
| 185 |
self.proc.stdin.write(line.encode("utf-8"))
|
| 186 |
self.proc.stdin.flush()
|
| 187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
def _read_json_line(self) -> Dict[str, Any]:
|
| 189 |
if self.proc is None or self.proc.stdout is None:
|
| 190 |
raise RuntimeError("RPC client is not running")
|
|
@@ -624,6 +642,8 @@ def run_pi(
|
|
| 624 |
provider: str = "openai",
|
| 625 |
model: str = "gpt-4o",
|
| 626 |
max_turns: int = 6,
|
|
|
|
|
|
|
| 627 |
) -> Generator[Dict[str, Any], None, str]:
|
| 628 |
"""
|
| 629 |
Run PI (native if available, else fallback) and yield events.
|
|
@@ -637,6 +657,8 @@ def run_pi(
|
|
| 637 |
provider=provider,
|
| 638 |
model=model,
|
| 639 |
tools="read,bash",
|
|
|
|
|
|
|
| 640 |
api_key=api_key,
|
| 641 |
)
|
| 642 |
try:
|
|
@@ -666,6 +688,8 @@ def run_pi_stream(
|
|
| 666 |
provider: str = "openai",
|
| 667 |
model: str = "gpt-4o",
|
| 668 |
max_turns: int = 6,
|
|
|
|
|
|
|
| 669 |
) -> Iterator[Dict[str, Any]]:
|
| 670 |
"""
|
| 671 |
Iterator that yields PI events. The *last* item is a sentinel dict
|
|
@@ -679,11 +703,20 @@ def run_pi_stream(
|
|
| 679 |
provider=provider,
|
| 680 |
model=model,
|
| 681 |
tools="read,bash",
|
|
|
|
|
|
|
| 682 |
api_key=api_key,
|
| 683 |
)
|
| 684 |
text_parts: List[str] = []
|
| 685 |
try:
|
| 686 |
client.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
request_id = client._next_id()
|
| 688 |
client._send({"id": request_id, "type": "prompt", "message": question})
|
| 689 |
|
|
|
|
| 93 |
provider: str,
|
| 94 |
model: str,
|
| 95 |
tools: str,
|
| 96 |
+
session_dir: Optional[Path] = None,
|
| 97 |
+
session_path: Optional[Path] = None,
|
| 98 |
api_key: str = "",
|
| 99 |
) -> None:
|
| 100 |
self.package_dir = package_dir
|
|
|
|
| 103 |
self.provider = provider
|
| 104 |
self.model = model
|
| 105 |
self.tools = tools
|
| 106 |
+
self.session_dir = session_dir
|
| 107 |
+
self.session_path = session_path
|
| 108 |
self.api_key = api_key
|
| 109 |
self.proc: Optional[subprocess.Popen[bytes]] = None
|
| 110 |
self.stderr_chunks: List[str] = []
|
|
|
|
| 136 |
cmd.extend(["--model", self.model])
|
| 137 |
if self.tools:
|
| 138 |
cmd.extend(["--tools", self.tools])
|
| 139 |
+
if self.session_dir:
|
| 140 |
+
cmd.extend(["--session-dir", str(self.session_dir)])
|
| 141 |
+
if self.session_path:
|
| 142 |
+
cmd.extend(["--session", str(self.session_path)])
|
| 143 |
return cmd
|
| 144 |
|
| 145 |
def start(self) -> None:
|
|
|
|
| 192 |
self.proc.stdin.write(line.encode("utf-8"))
|
| 193 |
self.proc.stdin.flush()
|
| 194 |
|
| 195 |
+
def call(self, command_type: str, **payload: Any) -> Dict[str, Any]:
    """Send one RPC command and block until its matching response arrives.

    Args:
        command_type: Value placed in the message's ``type`` field.
        **payload: Extra fields merged into the outgoing message. Keys named
            ``id`` or ``type`` are ignored so they cannot replace the
            envelope fields.

    Returns:
        The response event whose ``type`` is ``"response"`` and whose ``id``
        equals this request's id.

    Raises:
        RuntimeError: If the matching response does not report success.
    """
    request_id = self._next_id()
    message: Dict[str, Any] = {"id": request_id, "type": command_type}
    # The envelope fields must win: a payload that overwrote "id" would make
    # the wait loop below look for a response that is never sent.
    message.update({key: value for key, value in payload.items() if key not in message})
    self._send(message)
    while True:
        event = self._read_json_line()
        if event.get("type") != "response" or event.get("id") != request_id:
            # Not the reply to this request; keep reading.
            continue
        if not event.get("success", False):
            raise RuntimeError(f"RPC {command_type} failed: {event.get('error', 'unknown error')}")
        return event
|
| 205 |
+
|
| 206 |
def _read_json_line(self) -> Dict[str, Any]:
|
| 207 |
if self.proc is None or self.proc.stdout is None:
|
| 208 |
raise RuntimeError("RPC client is not running")
|
|
|
|
| 642 |
provider: str = "openai",
|
| 643 |
model: str = "gpt-4o",
|
| 644 |
max_turns: int = 6,
|
| 645 |
+
session_dir: Optional[Path] = None,
|
| 646 |
+
session_path: Optional[Path] = None,
|
| 647 |
) -> Generator[Dict[str, Any], None, str]:
|
| 648 |
"""
|
| 649 |
Run PI (native if available, else fallback) and yield events.
|
|
|
|
| 657 |
provider=provider,
|
| 658 |
model=model,
|
| 659 |
tools="read,bash",
|
| 660 |
+
session_dir=session_dir,
|
| 661 |
+
session_path=session_path,
|
| 662 |
api_key=api_key,
|
| 663 |
)
|
| 664 |
try:
|
|
|
|
| 688 |
provider: str = "openai",
|
| 689 |
model: str = "gpt-4o",
|
| 690 |
max_turns: int = 6,
|
| 691 |
+
session_dir: Optional[Path] = None,
|
| 692 |
+
session_path: Optional[Path] = None,
|
| 693 |
) -> Iterator[Dict[str, Any]]:
|
| 694 |
"""
|
| 695 |
Iterator that yields PI events. The *last* item is a sentinel dict
|
|
|
|
| 703 |
provider=provider,
|
| 704 |
model=model,
|
| 705 |
tools="read,bash",
|
| 706 |
+
session_dir=session_dir,
|
| 707 |
+
session_path=session_path,
|
| 708 |
api_key=api_key,
|
| 709 |
)
|
| 710 |
text_parts: List[str] = []
|
| 711 |
try:
|
| 712 |
client.start()
|
| 713 |
+
state_response = client.call("get_state")
|
| 714 |
+
state_data = state_response.get("data", {}) or {}
|
| 715 |
+
yield {
|
| 716 |
+
"type": "__session__",
|
| 717 |
+
"sessionFile": state_data.get("sessionFile"),
|
| 718 |
+
"sessionId": state_data.get("sessionId"),
|
| 719 |
+
}
|
| 720 |
request_id = client._next_id()
|
| 721 |
client._send({"id": request_id, "type": "prompt", "message": question})
|
| 722 |
|