ccm committed on
Commit
2fabb0d
·
1 Parent(s): 52a7b5c

Moving more things over into agent_server

Browse files
Files changed (2) hide show
  1. agent_server/agent_streaming.py +244 -0
  2. proxy.py +1 -247
agent_server/agent_streaming.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import contextlib
4
+ import threading
5
+ import time
6
+ import typing
7
+
8
+ import fastapi
9
+ import httpx
10
+
11
+ from agent_server.helpers import _sse_headers
12
+ from agent_server.sanitizing_think_tags import scrub_think_tags
13
+ from agent_server.std_tee import QueueWriter, _serialize_step
14
+
15
+
16
async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
    """
    Start the agent in a worker thread and stream its output incrementally.

    Three sources of incremental data are funnelled into this async generator:
      (1) live stdout/stderr lines (captured via the QueueWriter tee),
      (2) newly appended memory steps (polled every 100 ms),
      (3) any iterable/callback items the agent itself yields (if supported).
    A terminating ``{"__final__": ...}`` item carries the last answer;
    failures surface as ``{"__error__": str}`` items before the final one.

    Args:
        task: The task/prompt text handed to the agent.
        agent_obj: The agent instance to run; expected to expose one of
            ``run``/``run_stream``/``stream``/``stream_run``/
            ``run_with_callback``/``generate`` or be callable itself.

    Yields:
        stdout text, ``{"__step__": str}`` dicts, agent stream items,
        optionally ``{"__error__": str}``, then ``{"__final__": Any}``.
    """
    loop = asyncio.get_running_loop()
    q: asyncio.Queue = asyncio.Queue()
    agent_to_use = agent_obj

    stop_evt = threading.Event()

    def put_from_thread(item) -> None:
        # asyncio.Queue is NOT thread-safe: calling put_nowait directly from
        # a worker thread may not wake the consumer awaiting q.get(). Hand
        # the put over to the event loop instead (this is what the captured
        # ``loop`` is for).
        try:
            loop.call_soon_threadsafe(q.put_nowait, item)
        except Exception:
            pass

    # 1) stdout/stderr live tee
    # NOTE(review): QueueWriter enqueues into ``q`` itself from the worker
    # thread -- confirm it is loop-aware (uses call_soon_threadsafe) too.
    qwriter = QueueWriter(q)

    # 2) memory poller: forward newly appended agent memory steps.
    def poll_memory():
        last_len = 0
        while not stop_evt.is_set():
            try:
                steps = []
                try:
                    # Common API: agent.memory.get_full_steps()
                    steps = agent_to_use.memory.get_full_steps()  # type: ignore[attr-defined]
                except Exception:
                    # Fallbacks: different names across versions
                    steps = (
                        getattr(agent_to_use, "steps", [])
                        or getattr(agent_to_use, "memory", [])
                        or []
                    )
                if steps is None:
                    steps = []
                curr_len = len(steps)
                if curr_len > last_len:
                    new = steps[last_len:curr_len]
                    last_len = curr_len
                    for s in new:
                        s_text = _serialize_step(s)
                        if s_text:
                            put_from_thread({"__step__": s_text})
            except Exception:
                pass
            time.sleep(0.10)  # 100 ms cadence

    # 3) agent runner (may or may not yield)
    def run_agent():
        final_result = None
        try:
            with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
                qwriter
            ):
                used_iterable = False
                if hasattr(agent_to_use, "run") and callable(
                    getattr(agent_to_use, "run")
                ):
                    try:
                        res = agent_to_use.run(task, stream=True)
                        if hasattr(res, "__iter__") and not isinstance(
                            res, (str, bytes)
                        ):
                            used_iterable = True
                            for it in res:
                                put_from_thread(it)
                            # iterable may already contain the answer
                            final_result = None
                        else:
                            final_result = res
                    except TypeError:
                        # run(stream=True) not supported -> fall back
                        pass

                if final_result is None and not used_iterable:
                    # Try other common streaming signatures
                    for name in (
                        "run_stream",
                        "stream",
                        "stream_run",
                        "run_with_callback",
                    ):
                        fn = getattr(agent_to_use, name, None)
                        if callable(fn):
                            try:
                                res = fn(task)
                                if hasattr(res, "__iter__") and not isinstance(
                                    res, (str, bytes)
                                ):
                                    for it in res:
                                        put_from_thread(it)
                                    final_result = None
                                else:
                                    final_result = res
                                break
                            except TypeError:
                                # maybe callback signature
                                def cb(item):
                                    put_from_thread(item)

                                try:
                                    fn(task, cb)
                                    final_result = None
                                    break
                                except Exception:
                                    continue

                if final_result is None and not used_iterable:
                    # Last resort: synchronous run()/generate()/callable
                    if hasattr(agent_to_use, "run") and callable(
                        getattr(agent_to_use, "run")
                    ):
                        final_result = agent_to_use.run(task)
                    elif hasattr(agent_to_use, "generate") and callable(
                        getattr(agent_to_use, "generate")
                    ):
                        final_result = agent_to_use.generate(task)
                    elif callable(agent_to_use):
                        final_result = agent_to_use(task)

        except Exception as e:
            try:
                qwriter.flush()
            except Exception:
                pass
            put_from_thread({"__error__": str(e)})
        finally:
            try:
                qwriter.flush()
            except Exception:
                pass
            put_from_thread({"__final__": final_result})
            stop_evt.set()

    # Kick off threads
    mem_thread = threading.Thread(target=poll_memory, daemon=True)
    run_thread = threading.Thread(target=run_agent, daemon=True)
    mem_thread.start()
    run_thread.start()

    # Async consumer: drain until the terminating __final__ item arrives.
    while True:
        item = await q.get()
        yield item
        if isinstance(item, dict) and "__final__" in item:
            break
182
+
183
+
184
+ def _recursively_scrub(obj):
185
+ if isinstance(obj, str):
186
+ return scrub_think_tags(obj)
187
+ if isinstance(obj, dict):
188
+ return {k: _recursively_scrub(v) for k, v in obj.items()}
189
+ if isinstance(obj, list):
190
+ return [_recursively_scrub(v) for v in obj]
191
+ return obj
192
+
193
+
194
async def _proxy_upstream_chat_completions(
    body: dict, stream: bool, scrub_think: bool = False
):
    """
    Forward an OpenAI-style chat-completions request to the upstream server.

    Args:
        body: The JSON request body to relay verbatim.
        stream: When True, proxy the upstream SSE byte stream back to the
            client; otherwise perform a single POST and return its JSON.
        scrub_think: When True, strip think-tags from upstream text before
            returning it.

    Returns:
        A StreamingResponse (stream=True) or JSONResponse (stream=False);
        a 500 JSONResponse if UPSTREAM_OPENAI_BASE is not configured.
    """
    upstream_base = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
    if not upstream_base:
        # Fail fast with a clear error instead of POSTing to a bogus
        # relative URL. (This guard existed in proxy.py before the move
        # and was dropped in the refactor.)
        return fastapi.responses.JSONResponse(
            {"error": {"message": "UPSTREAM_OPENAI_BASE not configured"}},
            status_code=500,
        )

    # NOTE(review): the token is read from OPENAI_API_KEY although the old
    # local name suggested a Hugging Face token -- confirm which is intended.
    api_token = os.getenv("OPENAI_API_KEY")
    headers = {"Content-Type": "application/json"}
    if api_token:
        # Only attach Authorization when a token exists; some upstreams
        # reject an empty bearer header outright.
        headers["Authorization"] = f"Bearer {api_token}"
    url = f"{upstream_base}/chat/completions"

    if stream:

        async def proxy_stream():
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST", url, headers=headers, json=body
                ) as resp:
                    resp.raise_for_status()
                    if scrub_think:
                        # Pull text segments, scrub tags, and yield bytes
                        async for txt in resp.aiter_text():
                            try:
                                cleaned = scrub_think_tags(txt)
                                yield cleaned.encode("utf-8")
                            except Exception:
                                yield txt.encode("utf-8")
                    else:
                        async for chunk in resp.aiter_bytes():
                            yield chunk

        return fastapi.responses.StreamingResponse(
            proxy_stream(), media_type="text/event-stream", headers=_sse_headers()
        )
    else:
        async with httpx.AsyncClient(timeout=None) as client:
            r = await client.post(url, headers=headers, json=body)
            try:
                payload = r.json()
            except Exception:
                payload = {"status_code": r.status_code, "text": r.text}

            if scrub_think:
                try:
                    payload = _recursively_scrub(payload)
                except Exception:
                    pass

            return fastapi.responses.JSONResponse(
                status_code=r.status_code, content=payload
            )
proxy.py CHANGED
@@ -8,21 +8,17 @@ import time # For timestamps and sleeps
8
  import asyncio # For async operations
9
  import typing # For type annotations
10
  import logging # For logging
11
- import threading # For threading operations
12
 
13
  import fastapi
14
  import fastapi.responses
15
- import contextlib
16
 
17
  # Upstream pass-through
18
- import httpx
19
-
20
  from agent_server.formatting_reasoning import _format_reasoning_chunk, _extract_final_text, \
21
  _maybe_parse_final_from_stdout
22
  from agent_server.helpers import normalize_content_to_text, _messages_to_task, _openai_response, _sse_headers
23
  from agent_server.openai_schemas import ChatMessage, ChatCompletionRequest
24
  from agent_server.sanitizing_think_tags import scrub_think_tags
25
- from agent_server.std_tee import QueueWriter, _serialize_step
26
  from agents.code_writing_agents import (
27
  generate_code_writing_agent_without_tools,
28
  generate_code_writing_agent_with_search,
@@ -38,18 +34,8 @@ from agents.generator_and_critic import generate_generator_with_managed_critic
38
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
39
  log = logging.getLogger(__name__)
40
 
41
- # Config from env vars
42
- UPSTREAM_BASE = os.getenv("UPSTREAM_OPENAI_BASE", "").rstrip("/")
43
- HF_TOKEN = os.getenv("OPENAI_API_KEY")
44
  AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen3-1.7B")
45
 
46
- if not UPSTREAM_BASE:
47
- log.warning(
48
- "UPSTREAM_OPENAI_BASE is empty; OpenAI-compatible upstream calls will fail."
49
- )
50
- if not HF_TOKEN:
51
- log.warning("HF_TOKEN is empty; upstream may 401/403 if it requires auth.")
52
-
53
  # ================== FastAPI ==================
54
  app = fastapi.FastAPI()
55
 
@@ -59,238 +45,6 @@ async def healthz():
59
  return {"ok": True}
60
 
61
  # ---------- Agent streaming bridge (truly live) ----------
62
- async def run_agent_stream(task: str, agent_obj: typing.Optional[typing.Any] = None):
63
- """
64
- Start the agent in a worker thread.
65
- Stream THREE sources of incremental data into the async generator:
66
- (1) live stdout/stderr lines,
67
- (2) newly appended memory steps (polled),
68
- (3) any iterable the agent may yield (if supported).
69
- Finally emit a __final__ item with the last answer.
70
- """
71
- loop = asyncio.get_running_loop()
72
- q: asyncio.Queue = asyncio.Queue()
73
- agent_to_use = agent_obj
74
-
75
- stop_evt = threading.Event()
76
-
77
- # 1) stdout/stderr live tee
78
- qwriter = QueueWriter(q)
79
-
80
- # 2) memory poller
81
- def poll_memory():
82
- last_len = 0
83
- while not stop_evt.is_set():
84
- try:
85
- steps = []
86
- try:
87
- # Common API: agent.memory.get_full_steps()
88
- steps = agent_to_use.memory.get_full_steps() # type: ignore[attr-defined]
89
- except Exception:
90
- # Fallbacks: different names across versions
91
- steps = (
92
- getattr(agent_to_use, "steps", [])
93
- or getattr(agent_to_use, "memory", [])
94
- or []
95
- )
96
- if steps is None:
97
- steps = []
98
- curr_len = len(steps)
99
- if curr_len > last_len:
100
- new = steps[last_len:curr_len]
101
- last_len = curr_len
102
- for s in new:
103
- s_text = _serialize_step(s)
104
- if s_text:
105
- try:
106
- q.put_nowait({"__step__": s_text})
107
- except Exception:
108
- pass
109
- except Exception:
110
- pass
111
- time.sleep(0.10) # 100 ms cadence
112
-
113
- # 3) agent runner (may or may not yield)
114
- def run_agent():
115
- final_result = None
116
- try:
117
- with contextlib.redirect_stdout(qwriter), contextlib.redirect_stderr(
118
- qwriter
119
- ):
120
- used_iterable = False
121
- if hasattr(agent_to_use, "run") and callable(
122
- getattr(agent_to_use, "run")
123
- ):
124
- try:
125
- res = agent_to_use.run(task, stream=True)
126
- if hasattr(res, "__iter__") and not isinstance(
127
- res, (str, bytes)
128
- ):
129
- used_iterable = True
130
- for it in res:
131
- try:
132
- q.put_nowait(it)
133
- except Exception:
134
- pass
135
- final_result = (
136
- None # iterable may already contain the answer
137
- )
138
- else:
139
- final_result = res
140
- except TypeError:
141
- # run(stream=True) not supported -> fall back
142
- pass
143
-
144
- if final_result is None and not used_iterable:
145
- # Try other common streaming signatures
146
- for name in (
147
- "run_stream",
148
- "stream",
149
- "stream_run",
150
- "run_with_callback",
151
- ):
152
- fn = getattr(agent_to_use, name, None)
153
- if callable(fn):
154
- try:
155
- res = fn(task)
156
- if hasattr(res, "__iter__") and not isinstance(
157
- res, (str, bytes)
158
- ):
159
- for it in res:
160
- q.put_nowait(it)
161
- final_result = None
162
- else:
163
- final_result = res
164
- break
165
- except TypeError:
166
- # maybe callback signature
167
- def cb(item):
168
- try:
169
- q.put_nowait(item)
170
- except Exception:
171
- pass
172
-
173
- try:
174
- fn(task, cb)
175
- final_result = None
176
- break
177
- except Exception:
178
- continue
179
-
180
- if final_result is None and not used_iterable:
181
- pass # (typo guard removed below)
182
-
183
- if final_result is None and not used_iterable:
184
- # Last resort: synchronous run()/generate()/callable
185
- if hasattr(agent_to_use, "run") and callable(
186
- getattr(agent_to_use, "run")
187
- ):
188
- final_result = agent_to_use.run(task)
189
- elif hasattr(agent_to_use, "generate") and callable(
190
- getattr(agent_to_use, "generate")
191
- ):
192
- final_result = agent_to_use.generate(task)
193
- elif callable(agent_to_use):
194
- final_result = agent_to_use(task)
195
-
196
- except Exception as e:
197
- try:
198
- qwriter.flush()
199
- except Exception:
200
- pass
201
- try:
202
- q.put_nowait({"__error__": str(e)})
203
- except Exception:
204
- pass
205
- finally:
206
- try:
207
- qwriter.flush()
208
- except Exception:
209
- pass
210
- try:
211
- q.put_nowait({"__final__": final_result})
212
- except Exception:
213
- pass
214
- stop_evt.set()
215
-
216
- # Kick off threads
217
- mem_thread = threading.Thread(target=poll_memory, daemon=True)
218
- run_thread = threading.Thread(target=run_agent, daemon=True)
219
- mem_thread.start()
220
- run_thread.start()
221
-
222
- # Async consumer
223
- while True:
224
- item = await q.get()
225
- yield item
226
- if isinstance(item, dict) and "__final__" in item:
227
- break
228
-
229
-
230
- def _recursively_scrub(obj):
231
- if isinstance(obj, str):
232
- return scrub_think_tags(obj)
233
- if isinstance(obj, dict):
234
- return {k: _recursively_scrub(v) for k, v in obj.items()}
235
- if isinstance(obj, list):
236
- return [_recursively_scrub(v) for v in obj]
237
- return obj
238
-
239
-
240
- async def _proxy_upstream_chat_completions(
241
- body: dict, stream: bool, scrub_think: bool = False
242
- ):
243
- if not UPSTREAM_BASE:
244
- return fastapi.responses.JSONResponse(
245
- {"error": {"message": "UPSTREAM_OPENAI_BASE not configured"}},
246
- status_code=500,
247
- )
248
- headers = {
249
- "Authorization": f"Bearer {HF_TOKEN}" if HF_TOKEN else "",
250
- "Content-Type": "application/json",
251
- }
252
- url = f"{UPSTREAM_BASE}/chat/completions"
253
-
254
- if stream:
255
-
256
- async def proxy_stream():
257
- async with httpx.AsyncClient(timeout=None) as client:
258
- async with client.stream(
259
- "POST", url, headers=headers, json=body
260
- ) as resp:
261
- resp.raise_for_status()
262
- if scrub_think:
263
- # Pull text segments, scrub tags, and yield bytes
264
- async for txt in resp.aiter_text():
265
- try:
266
- cleaned = scrub_think_tags(txt)
267
- yield cleaned.encode("utf-8")
268
- except Exception:
269
- yield txt.encode("utf-8")
270
- else:
271
- async for chunk in resp.aiter_bytes():
272
- yield chunk
273
-
274
- return fastapi.responses.StreamingResponse(
275
- proxy_stream(), media_type="text/event-stream", headers=_sse_headers()
276
- )
277
- else:
278
- async with httpx.AsyncClient(timeout=None) as client:
279
- r = await client.post(url, headers=headers, json=body)
280
- try:
281
- payload = r.json()
282
- except Exception:
283
- payload = {"status_code": r.status_code, "text": r.text}
284
-
285
- if scrub_think:
286
- try:
287
- payload = _recursively_scrub(payload)
288
- except Exception:
289
- pass
290
-
291
- return fastapi.responses.JSONResponse(
292
- status_code=r.status_code, content=payload
293
- )
294
 
295
 
296
  # ---------- Endpoints ----------
 
8
  import asyncio # For async operations
9
  import typing # For type annotations
10
  import logging # For logging
 
11
 
12
  import fastapi
13
  import fastapi.responses
 
14
 
15
  # Upstream pass-through
16
+ from agent_server.agent_streaming import run_agent_stream, _proxy_upstream_chat_completions
 
17
  from agent_server.formatting_reasoning import _format_reasoning_chunk, _extract_final_text, \
18
  _maybe_parse_final_from_stdout
19
  from agent_server.helpers import normalize_content_to_text, _messages_to_task, _openai_response, _sse_headers
20
  from agent_server.openai_schemas import ChatMessage, ChatCompletionRequest
21
  from agent_server.sanitizing_think_tags import scrub_think_tags
 
22
  from agents.code_writing_agents import (
23
  generate_code_writing_agent_without_tools,
24
  generate_code_writing_agent_with_search,
 
34
  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO").upper())
35
  log = logging.getLogger(__name__)
36
 
 
 
 
37
  AGENT_MODEL = os.getenv("AGENT_MODEL", "Qwen/Qwen3-1.7B")
38
 
 
 
 
 
 
 
 
39
  # ================== FastAPI ==================
40
  app = fastapi.FastAPI()
41
 
 
45
  return {"ok": True}
46
 
47
  # ---------- Agent streaming bridge (truly live) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
 
50
  # ---------- Endpoints ----------