IPF committed on
Commit
38c7e42
·
verified ·
1 Parent(s): 5c81f1e

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +78 -32
  2. pi_wrapper.py +34 -1
app.py CHANGED
@@ -1230,6 +1230,22 @@ def _resolve_example_corpus(corpus_dir: Path, question: str) -> Path:
1230
  return candidate if candidate.exists() else corpus_dir
1231
 
1232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1233
  def _ensure_corpus_ready(
1234
  use_default: bool,
1235
  uploaded_zip: Optional[str],
@@ -1836,32 +1852,58 @@ def run_search(
1836
  mode: str,
1837
  corpus_source: str,
1838
  uploaded_zip: Optional[str],
1839
- ) -> Generator[tuple[str, str], None, None]:
1840
- """Yields (terminal_html, status_str)."""
 
1841
 
1842
  if not api_key or not api_key.strip():
1843
- yield _TERM_IDLE, "⚠ OpenAI API Key is required."
1844
  return
1845
  if not question or not question.strip():
1846
- yield _TERM_IDLE, "⚠ Please enter a question."
1847
  return
1848
 
 
1849
  use_default_corpus = corpus_source == DEFAULT_CORPUS_LABEL
1850
  selected_domain = _selected_example_corpus(question) if use_default_corpus else None
1851
- try:
1852
- corpus_root = _ensure_corpus_ready(use_default_corpus, uploaded_zip, selected_domain)
1853
- except ValueError as exc:
1854
- yield _TERM_IDLE, f"⚠ {exc}"
1855
- return
1856
- except Exception as exc:
1857
- yield _TERM_IDLE, f"⚠ Corpus error: {exc}"
1858
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1859
 
 
1860
  cwd = _resolve_example_corpus(corpus_root, question) if use_default_corpus else corpus_root
1861
 
1862
  prompt = build_ir_prompt(question, cwd) if mode == "IR" else build_benchmark_prompt(question, cwd)
1863
  state = _initial_terminal_state(question, cwd, mode, model, max_turns, selected_domain)
1864
- yield _render_terminal_state(state), "⚙ Running…"
1865
 
1866
  try:
1867
  for event in run_pi_stream(
@@ -1871,26 +1913,27 @@ def run_search(
1871
  provider="openai",
1872
  model=model,
1873
  max_turns=max_turns,
 
 
1874
  ):
 
 
 
 
 
 
1875
  if _apply_event(state, event):
1876
  status = "✓ Completed" if event.get("type") == "__final__" else "⚙ Running…"
1877
- yield _render_terminal_state(state), status
1878
 
1879
  except Exception as exc:
1880
  _apply_event(state, {"type": "error", "error": str(exc)})
1881
- yield _render_terminal_state(state), f"⚠ {exc}"
1882
  return
1883
- finally:
1884
- try:
1885
- p = cwd.parent
1886
- if p.name.startswith("pi_corpus_"):
1887
- shutil.rmtree(p, ignore_errors=True)
1888
- except Exception:
1889
- pass
1890
 
1891
  if not state.get("finalized"):
1892
  _apply_event(state, {"type": "__final__", "text": ""})
1893
- yield _render_terminal_state(state), "✓ Completed"
1894
 
1895
 
1896
  # ──────────────────────────────────────────────────────────────────────────
@@ -1899,6 +1942,7 @@ def run_search(
1899
  def build_ui() -> gr.Blocks:
1900
  hero_logo_src = f"/gradio_api/file={HERO_LOGO_PATH}"
1901
  with gr.Blocks(title="DCI-Agent Search") as demo:
 
1902
  with gr.Row(equal_height=False):
1903
 
1904
  # ── Left sidebar ───────────────────────────────────────────────
@@ -2053,28 +2097,30 @@ def build_ui() -> gr.Blocks:
2053
  outputs=[uploaded_zip, corpus_hint],
2054
  )
2055
 
2056
- def clear_all():
 
2057
  return (
2058
- _TERM_IDLE, "", "Ready ✔",
2059
  )
2060
 
2061
  clear_btn.click(
2062
  fn=clear_all,
2063
- outputs=[terminal, terminal_question, status_text],
 
2064
  )
2065
 
2066
- def _run_with_qa_mode(api_key, question, model, max_turns, corpus_source, uploaded_zip):
2067
- yield from run_search(api_key, question, model, max_turns, "QA", corpus_source, uploaded_zip)
2068
 
2069
  run_event = run_btn.click(
2070
  fn=_run_with_qa_mode,
2071
- inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip],
2072
- outputs=[terminal, status_text],
2073
  )
2074
  terminal_question.submit(
2075
  fn=_run_with_qa_mode,
2076
- inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip],
2077
- outputs=[terminal, status_text],
2078
  )
2079
  stop_btn.click(fn=None, cancels=[run_event])
2080
 
 
1230
  return candidate if candidate.exists() else corpus_dir
1231
 
1232
 
1233
def _cleanup_runtime_state(runtime_state: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Remove the directories recorded in *runtime_state* and return an empty state.

    The ``"corpus_root"`` and ``"session_dir"`` entries, when present and
    non-empty, are treated as directory paths and deleted recursively.
    Deletion is best-effort: missing paths and filesystem errors are ignored.

    Args:
        runtime_state: Mapping carried between UI callbacks, or ``None``.
            It is copied, never mutated.

    Returns:
        A new empty dict, suitable for storing back as the cleared state.
    """
    # NOTE(review): corpus_root is removed unconditionally; confirm that it can
    # never point at a shared/default corpus directory that must survive.
    state = dict(runtime_state or {})
    for key in ("corpus_root", "session_dir"):
        path_value = state.get(key)
        if not path_value:
            continue
        try:
            # ignore_errors=True already tolerates a missing directory, so no
            # separate exists() probe (and its race window) is needed.
            shutil.rmtree(Path(path_value), ignore_errors=True)
        except (TypeError, ValueError, OSError):
            # Malformed path value (wrong type, embedded NUL, ...): cleanup
            # stays best-effort, but unrelated bugs are no longer swallowed.
            continue
    return {}
1247
+
1248
+
1249
  def _ensure_corpus_ready(
1250
  use_default: bool,
1251
  uploaded_zip: Optional[str],
 
1852
  mode: str,
1853
  corpus_source: str,
1854
  uploaded_zip: Optional[str],
1855
+ runtime_state: Optional[Dict[str, Any]],
1856
+ ) -> Generator[tuple[str, str, Dict[str, Any]], None, None]:
1857
+ """Yields (terminal_html, status_str, runtime_state)."""
1858
 
1859
  if not api_key or not api_key.strip():
1860
+ yield _TERM_IDLE, "⚠ OpenAI API Key is required.", dict(runtime_state or {})
1861
  return
1862
  if not question or not question.strip():
1863
+ yield _TERM_IDLE, "⚠ Please enter a question.", dict(runtime_state or {})
1864
  return
1865
 
1866
+ runtime = dict(runtime_state or {})
1867
  use_default_corpus = corpus_source == DEFAULT_CORPUS_LABEL
1868
  selected_domain = _selected_example_corpus(question) if use_default_corpus else None
1869
+ corpus_key = json.dumps(
1870
+ {
1871
+ "source": corpus_source,
1872
+ "uploaded_zip": uploaded_zip or "",
1873
+ "selected_domain": selected_domain or "",
1874
+ },
1875
+ ensure_ascii=False,
1876
+ sort_keys=True,
1877
+ )
1878
+ corpus_root_value = runtime.get("corpus_root")
1879
+ corpus_root = Path(corpus_root_value) if corpus_root_value else None
1880
+ if runtime.get("corpus_key") != corpus_key or corpus_root is None or not corpus_root.exists():
1881
+ runtime = _cleanup_runtime_state(runtime)
1882
+ try:
1883
+ corpus_root = _ensure_corpus_ready(use_default_corpus, uploaded_zip, selected_domain)
1884
+ except ValueError as exc:
1885
+ yield _TERM_IDLE, f"⚠ {exc}", runtime
1886
+ return
1887
+ except Exception as exc:
1888
+ yield _TERM_IDLE, f"⚠ Corpus error: {exc}", runtime
1889
+ return
1890
+ session_dir = Path(tempfile.mkdtemp(prefix="pi_session_"))
1891
+ runtime.update(
1892
+ {
1893
+ "corpus_key": corpus_key,
1894
+ "corpus_root": str(corpus_root),
1895
+ "session_dir": str(session_dir),
1896
+ "session_file": "",
1897
+ "session_id": "",
1898
+ }
1899
+ )
1900
 
1901
+ assert corpus_root is not None
1902
  cwd = _resolve_example_corpus(corpus_root, question) if use_default_corpus else corpus_root
1903
 
1904
  prompt = build_ir_prompt(question, cwd) if mode == "IR" else build_benchmark_prompt(question, cwd)
1905
  state = _initial_terminal_state(question, cwd, mode, model, max_turns, selected_domain)
1906
+ yield _render_terminal_state(state), "⚙ Running…", runtime
1907
 
1908
  try:
1909
  for event in run_pi_stream(
 
1913
  provider="openai",
1914
  model=model,
1915
  max_turns=max_turns,
1916
+ session_dir=Path(runtime["session_dir"]) if runtime.get("session_dir") else None,
1917
+ session_path=Path(runtime["session_file"]) if runtime.get("session_file") else None,
1918
  ):
1919
+ if event.get("type") == "__session__":
1920
+ if event.get("sessionFile"):
1921
+ runtime["session_file"] = str(event["sessionFile"])
1922
+ if event.get("sessionId"):
1923
+ runtime["session_id"] = str(event["sessionId"])
1924
+ continue
1925
  if _apply_event(state, event):
1926
  status = "✓ Completed" if event.get("type") == "__final__" else "⚙ Running…"
1927
+ yield _render_terminal_state(state), status, runtime
1928
 
1929
  except Exception as exc:
1930
  _apply_event(state, {"type": "error", "error": str(exc)})
1931
+ yield _render_terminal_state(state), f"⚠ {exc}", runtime
1932
  return
 
 
 
 
 
 
 
1933
 
1934
  if not state.get("finalized"):
1935
  _apply_event(state, {"type": "__final__", "text": ""})
1936
+ yield _render_terminal_state(state), "✓ Completed", runtime
1937
 
1938
 
1939
  # ──────────────────────────────────────────────────────────────────────────
 
1942
  def build_ui() -> gr.Blocks:
1943
  hero_logo_src = f"/gradio_api/file={HERO_LOGO_PATH}"
1944
  with gr.Blocks(title="DCI-Agent Search") as demo:
1945
+ runtime_state = gr.State({})
1946
  with gr.Row(equal_height=False):
1947
 
1948
  # ── Left sidebar ───────────────────────────────────────────────
 
2097
  outputs=[uploaded_zip, corpus_hint],
2098
  )
2099
 
2100
+ def clear_all(runtime_state):
2101
+ cleared_runtime = _cleanup_runtime_state(runtime_state)
2102
  return (
2103
+ _TERM_IDLE, "", "Ready ✔", cleared_runtime,
2104
  )
2105
 
2106
  clear_btn.click(
2107
  fn=clear_all,
2108
+ inputs=[runtime_state],
2109
+ outputs=[terminal, terminal_question, status_text, runtime_state],
2110
  )
2111
 
2112
+ def _run_with_qa_mode(api_key, question, model, max_turns, corpus_source, uploaded_zip, runtime_state):
2113
+ yield from run_search(api_key, question, model, max_turns, "QA", corpus_source, uploaded_zip, runtime_state)
2114
 
2115
  run_event = run_btn.click(
2116
  fn=_run_with_qa_mode,
2117
+ inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip, runtime_state],
2118
+ outputs=[terminal, status_text, runtime_state],
2119
  )
2120
  terminal_question.submit(
2121
  fn=_run_with_qa_mode,
2122
+ inputs=[api_key, terminal_question, model, max_turns, corpus_source, uploaded_zip, runtime_state],
2123
+ outputs=[terminal, status_text, runtime_state],
2124
  )
2125
  stop_btn.click(fn=None, cancels=[run_event])
2126
 
pi_wrapper.py CHANGED
@@ -93,6 +93,8 @@ class NativePiClient:
93
  provider: str,
94
  model: str,
95
  tools: str,
 
 
96
  api_key: str = "",
97
  ) -> None:
98
  self.package_dir = package_dir
@@ -101,6 +103,8 @@ class NativePiClient:
101
  self.provider = provider
102
  self.model = model
103
  self.tools = tools
 
 
104
  self.api_key = api_key
105
  self.proc: Optional[subprocess.Popen[bytes]] = None
106
  self.stderr_chunks: List[str] = []
@@ -132,7 +136,10 @@ class NativePiClient:
132
  cmd.extend(["--model", self.model])
133
  if self.tools:
134
  cmd.extend(["--tools", self.tools])
135
- cmd.append("--no-session")
 
 
 
136
  return cmd
137
 
138
  def start(self) -> None:
@@ -185,6 +192,17 @@ class NativePiClient:
185
  self.proc.stdin.write(line.encode("utf-8"))
186
  self.proc.stdin.flush()
187
 
 
 
 
 
 
 
 
 
 
 
 
188
  def _read_json_line(self) -> Dict[str, Any]:
189
  if self.proc is None or self.proc.stdout is None:
190
  raise RuntimeError("RPC client is not running")
@@ -624,6 +642,8 @@ def run_pi(
624
  provider: str = "openai",
625
  model: str = "gpt-4o",
626
  max_turns: int = 6,
 
 
627
  ) -> Generator[Dict[str, Any], None, str]:
628
  """
629
  Run PI (native if available, else fallback) and yield events.
@@ -637,6 +657,8 @@ def run_pi(
637
  provider=provider,
638
  model=model,
639
  tools="read,bash",
 
 
640
  api_key=api_key,
641
  )
642
  try:
@@ -666,6 +688,8 @@ def run_pi_stream(
666
  provider: str = "openai",
667
  model: str = "gpt-4o",
668
  max_turns: int = 6,
 
 
669
  ) -> Iterator[Dict[str, Any]]:
670
  """
671
  Iterator that yields PI events. The *last* item is a sentinel dict
@@ -679,11 +703,20 @@ def run_pi_stream(
679
  provider=provider,
680
  model=model,
681
  tools="read,bash",
 
 
682
  api_key=api_key,
683
  )
684
  text_parts: List[str] = []
685
  try:
686
  client.start()
 
 
 
 
 
 
 
687
  request_id = client._next_id()
688
  client._send({"id": request_id, "type": "prompt", "message": question})
689
 
 
93
  provider: str,
94
  model: str,
95
  tools: str,
96
+ session_dir: Optional[Path] = None,
97
+ session_path: Optional[Path] = None,
98
  api_key: str = "",
99
  ) -> None:
100
  self.package_dir = package_dir
 
103
  self.provider = provider
104
  self.model = model
105
  self.tools = tools
106
+ self.session_dir = session_dir
107
+ self.session_path = session_path
108
  self.api_key = api_key
109
  self.proc: Optional[subprocess.Popen[bytes]] = None
110
  self.stderr_chunks: List[str] = []
 
136
  cmd.extend(["--model", self.model])
137
  if self.tools:
138
  cmd.extend(["--tools", self.tools])
139
+ if self.session_dir:
140
+ cmd.extend(["--session-dir", str(self.session_dir)])
141
+ if self.session_path:
142
+ cmd.extend(["--session", str(self.session_path)])
143
  return cmd
144
 
145
  def start(self) -> None:
 
192
  self.proc.stdin.write(line.encode("utf-8"))
193
  self.proc.stdin.flush()
194
 
195
def call(self, command_type: str, **payload: Any) -> Dict[str, Any]:
    """Send one RPC command and block until its matching response arrives.

    A fresh request id is taken from ``self._next_id()``, the command is
    written with ``self._send()``, and lines are read with
    ``self._read_json_line()`` until one has ``type == "response"`` and the
    same ``id``. Every other line read while waiting is discarded.

    Args:
        command_type: Value placed in the message's ``"type"`` field.
        **payload: Extra fields merged into the outgoing message.

    Returns:
        The matching response event.

    Raises:
        RuntimeError: If the matching response does not report success.
    """
    request_id = self._next_id()
    # Reserved envelope keys go last so a payload keyword named "id" or "type"
    # cannot overwrite them; otherwise the loop below would wait for a
    # response id that was never sent.
    message = {**payload, "id": request_id, "type": command_type}
    self._send(message)
    while True:
        event = self._read_json_line()
        if event.get("type") != "response" or event.get("id") != request_id:
            continue
        if not event.get("success", False):
            raise RuntimeError(f"RPC {command_type} failed: {event.get('error', 'unknown error')}")
        return event
205
+
206
  def _read_json_line(self) -> Dict[str, Any]:
207
  if self.proc is None or self.proc.stdout is None:
208
  raise RuntimeError("RPC client is not running")
 
642
  provider: str = "openai",
643
  model: str = "gpt-4o",
644
  max_turns: int = 6,
645
+ session_dir: Optional[Path] = None,
646
+ session_path: Optional[Path] = None,
647
  ) -> Generator[Dict[str, Any], None, str]:
648
  """
649
  Run PI (native if available, else fallback) and yield events.
 
657
  provider=provider,
658
  model=model,
659
  tools="read,bash",
660
+ session_dir=session_dir,
661
+ session_path=session_path,
662
  api_key=api_key,
663
  )
664
  try:
 
688
  provider: str = "openai",
689
  model: str = "gpt-4o",
690
  max_turns: int = 6,
691
+ session_dir: Optional[Path] = None,
692
+ session_path: Optional[Path] = None,
693
  ) -> Iterator[Dict[str, Any]]:
694
  """
695
  Iterator that yields PI events. The *last* item is a sentinel dict
 
703
  provider=provider,
704
  model=model,
705
  tools="read,bash",
706
+ session_dir=session_dir,
707
+ session_path=session_path,
708
  api_key=api_key,
709
  )
710
  text_parts: List[str] = []
711
  try:
712
  client.start()
713
+ state_response = client.call("get_state")
714
+ state_data = state_response.get("data", {}) or {}
715
+ yield {
716
+ "type": "__session__",
717
+ "sessionFile": state_data.get("sessionFile"),
718
+ "sessionId": state_data.get("sessionId"),
719
+ }
720
  request_id = client._next_id()
721
  client._send({"id": request_id, "type": "prompt", "message": question})
722