davidkim205 commited on
Commit
72afe31
·
1 Parent(s): 9d7364f

진행 과정 출력하도록 코드 수정

Browse files
Files changed (2) hide show
  1. api_server.py +114 -29
  2. gradio_app.py +21 -9
api_server.py CHANGED
@@ -1,4 +1,6 @@
1
  import json
 
 
2
  from queue import Empty, Queue
3
  from threading import Thread
4
  from typing import Optional
@@ -14,6 +16,71 @@ from persona.make_persona import make_persona
14
  app = FastAPI()
15
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class PersonaRequest(BaseModel):
18
  info: str
19
 
@@ -44,6 +111,21 @@ def _sse(payload: dict) -> str:
44
  return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  @app.post("/analyze/")
48
  async def analyze(request: QueryRequest):
49
  query = (request.query or "").strip()
@@ -52,25 +134,35 @@ async def analyze(request: QueryRequest):
52
  persona_name = (request.persona_name or "").strip() or None
53
 
54
  if not stream:
55
- result = run_pipeline(
56
- query,
57
- persona_name=persona_name,
58
- status_callback=None,
59
- stream_callback=None,
60
- stream=False,
61
- )
62
- return JSONResponse(
63
- content=jsonable_encoder(
64
- {
65
- "type": "result",
66
- "query": result.query,
67
- "ticker": result.ticker,
68
- "analysis_type": result.analysis_type,
69
- "data_context": result.data_context,
70
- "llm_response": result.llm_response,
71
- "timestamp": getattr(result, "timestamp", None),
72
- }
 
 
 
 
 
 
73
  )
 
 
 
 
74
  )
75
 
76
  def event_stream():
@@ -84,6 +176,8 @@ async def analyze(request: QueryRequest):
84
  event_queue.put({"type": "delta", "delta": delta})
85
 
86
  def worker():
 
 
87
  try:
88
  result = run_pipeline(
89
  query,
@@ -92,20 +186,11 @@ async def analyze(request: QueryRequest):
92
  stream_callback=on_delta if stream else None,
93
  stream=stream,
94
  )
95
- event_queue.put(
96
- {
97
- "type": "result",
98
- "query": result.query,
99
- "ticker": result.ticker,
100
- "analysis_type": result.analysis_type,
101
- "data_context": result.data_context,
102
- "llm_response": result.llm_response,
103
- "timestamp": getattr(result, "timestamp", None),
104
- }
105
- )
106
  except Exception as exc:
107
  event_queue.put({"type": "error", "message": str(exc)})
108
  finally:
 
109
  event_queue.put({"type": "done"})
110
 
111
  yield _sse({"type": "status", "message": "요청 수신. 분석 준비 중..."})
 
1
  import json
2
+ import sys
3
+ import threading
4
  from queue import Empty, Queue
5
  from threading import Thread
6
  from typing import Optional
 
16
  app = FastAPI()
17
 
18
 
19
+ class _ThreadStdoutProxy:
20
+ def __init__(self, target):
21
+ self._target = target
22
+ self._handlers = {}
23
+ self._lock = threading.RLock()
24
+ self.encoding = getattr(target, "encoding", "utf-8")
25
+ self.errors = getattr(target, "errors", None)
26
+
27
+ def register(self, thread_id: int, handler) -> None:
28
+ with self._lock:
29
+ self._handlers[thread_id] = handler
30
+
31
+ def unregister(self, thread_id: int) -> None:
32
+ with self._lock:
33
+ self._handlers.pop(thread_id, None)
34
+
35
+ def _resolve(self):
36
+ thread_id = threading.get_ident()
37
+ with self._lock:
38
+ return self._handlers.get(thread_id), self._target
39
+
40
+ def write(self, data):
41
+ handler, target = self._resolve()
42
+ if handler:
43
+ return handler.write(data)
44
+ return target.write(data)
45
+
46
+ def flush(self):
47
+ handler, target = self._resolve()
48
+ if handler:
49
+ handler.flush()
50
+ return target.flush()
51
+
52
+ def isatty(self):
53
+ return getattr(self._target, "isatty", lambda: False)()
54
+
55
+ def fileno(self):
56
+ return self._target.fileno()
57
+
58
+ def writable(self):
59
+ return True
60
+
61
+ def __getattr__(self, name):
62
+ return getattr(self._target, name)
63
+
64
+
65
+ class _QueueingStdoutTee:
66
+ def __init__(self, target, event_queue: Queue):
67
+ self._target = target
68
+ self._event_queue = event_queue
69
+
70
+ def write(self, data):
71
+ written = self._target.write(data)
72
+ if data:
73
+ self._event_queue.put({"type": "stdout", "message": data})
74
+ return written
75
+
76
+ def flush(self):
77
+ self._target.flush()
78
+
79
+
80
+ _stdout_proxy = _ThreadStdoutProxy(sys.stdout)
81
+ sys.stdout = _stdout_proxy
82
+
83
+
84
  class PersonaRequest(BaseModel):
85
  info: str
86
 
 
111
  return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
112
 
113
 
114
+ def _build_result_payload(result, stdout: str = "") -> dict:
115
+ payload = {
116
+ "type": "result",
117
+ "query": result.query,
118
+ "ticker": result.ticker,
119
+ "analysis_type": result.analysis_type,
120
+ "data_context": result.data_context,
121
+ "llm_response": result.llm_response,
122
+ "timestamp": getattr(result, "timestamp", None),
123
+ }
124
+ if stdout:
125
+ payload["stdout"] = stdout
126
+ return payload
127
+
128
+
129
  @app.post("/analyze/")
130
  async def analyze(request: QueryRequest):
131
  query = (request.query or "").strip()
 
134
  persona_name = (request.persona_name or "").strip() or None
135
 
136
  if not stream:
137
+ stdout_messages = []
138
+
139
+ class _ListStdoutTee:
140
+ def __init__(self, target):
141
+ self._target = target
142
+
143
+ def write(self, data):
144
+ written = self._target.write(data)
145
+ if data:
146
+ stdout_messages.append(data)
147
+ return written
148
+
149
+ def flush(self):
150
+ self._target.flush()
151
+
152
+ thread_id = threading.get_ident()
153
+ _stdout_proxy.register(thread_id, _ListStdoutTee(_stdout_proxy._target))
154
+ try:
155
+ result = run_pipeline(
156
+ query,
157
+ persona_name=persona_name,
158
+ status_callback=None,
159
+ stream_callback=None,
160
+ stream=False,
161
  )
162
+ finally:
163
+ _stdout_proxy.unregister(thread_id)
164
+ return JSONResponse(
165
+ content=jsonable_encoder(_build_result_payload(result, stdout="".join(stdout_messages)))
166
  )
167
 
168
  def event_stream():
 
176
  event_queue.put({"type": "delta", "delta": delta})
177
 
178
  def worker():
179
+ thread_id = threading.get_ident()
180
+ _stdout_proxy.register(thread_id, _QueueingStdoutTee(_stdout_proxy._target, event_queue))
181
  try:
182
  result = run_pipeline(
183
  query,
 
186
  stream_callback=on_delta if stream else None,
187
  stream=stream,
188
  )
189
+ event_queue.put(_build_result_payload(result))
 
 
 
 
 
 
 
 
 
 
190
  except Exception as exc:
191
  event_queue.put({"type": "error", "message": str(exc)})
192
  finally:
193
+ _stdout_proxy.unregister(thread_id)
194
  event_queue.put({"type": "done"})
195
 
196
  yield _sse({"type": "status", "message": "요청 수신. 분석 준비 중..."})
gradio_app.py CHANGED
@@ -193,7 +193,9 @@ def stream_analyze(query, persona_name, endpoint):
193
  return
194
 
195
  text_acc = ""
 
196
  meta_text = ""
 
197
  first_delta_received = False
198
  loading_msg = "요청 중..."
199
  worker_finished = False
@@ -202,6 +204,14 @@ def stream_analyze(query, persona_name, endpoint):
202
 
203
  event_queue = Queue()
204
 
 
 
 
 
 
 
 
 
205
  def reader_worker():
206
  try:
207
  payload = {"query": query}
@@ -271,13 +281,20 @@ def stream_analyze(query, persona_name, endpoint):
271
 
272
  elif event_type == "result":
273
  result = payload.get("result", payload)
274
- meta_text = json.dumps(result, ensure_ascii=False, indent=2)
 
275
  if not text_acc:
276
  llm_response = result.get("llm_response", "")
277
  if llm_response:
278
  first_delta_received = True
279
  text_acc = llm_response
280
 
 
 
 
 
 
 
281
  elif event_type == "error":
282
  message = payload.get("message", "알 수 없는 오류")
283
  if text_acc:
@@ -417,7 +434,7 @@ def create_app(default_endpoint):
417
 
418
  #answer-wrapper {
419
  min-height: 420px;
420
- max-height: 62vh;
421
  overflow-y: auto !important;
422
  border: 1px solid var(--ws-border) !important;
423
  border-radius: 14px !important;
@@ -479,12 +496,7 @@ def create_app(default_endpoint):
479
  border: 1px solid #d9e2ec;
480
  border-radius: 6px;
481
  padding: 0.1em 0.35em;
482
- font-size: 0.92em;
483
- font-family: "JetBrains Mono", "IBM Plex Mono", "Source Code Pro", monospace !important;
484
- letter-spacing: 0.01em;
485
- }
486
-
487
- #answer-wrapper pre {
488
  background: #0f172a !important;
489
  color: #e2e8f0 !important;
490
  border-radius: 10px;
@@ -596,7 +608,7 @@ def create_app(default_endpoint):
596
 
597
  answer = gr.Markdown(value=to_markdown(""), label="답변", elem_id="answer-wrapper")
598
  timer = gr.Markdown(value=timer_text("0.0초"), elem_id="timer-row")
599
- meta = gr.Code(label="최종 (JSON)", language="json", elem_id="meta-box")
600
 
601
  gr.HTML(AUTO_SCROLL_SCRIPT, visible=False)
602
 
 
193
  return
194
 
195
  text_acc = ""
196
+ result_meta_text = ""
197
  meta_text = ""
198
+ stdout_acc = ""
199
  first_delta_received = False
200
  loading_msg = "요청 중..."
201
  worker_finished = False
 
204
 
205
  event_queue = Queue()
206
 
207
+ def build_meta_text():
208
+ sections = []
209
+ if result_meta_text:
210
+ sections.append(result_meta_text)
211
+ if stdout_acc:
212
+ sections.append(f"[stdout]\n{stdout_acc}")
213
+ return "\n\n".join(sections)
214
+
215
  def reader_worker():
216
  try:
217
  payload = {"query": query}
 
281
 
282
  elif event_type == "result":
283
  result = payload.get("result", payload)
284
+ result_meta_text = json.dumps(result, ensure_ascii=False, indent=2)
285
+ meta_text = build_meta_text()
286
  if not text_acc:
287
  llm_response = result.get("llm_response", "")
288
  if llm_response:
289
  first_delta_received = True
290
  text_acc = llm_response
291
 
292
+ elif event_type == "stdout":
293
+ message = payload.get("message", "")
294
+ if message:
295
+ stdout_acc += message
296
+ meta_text = build_meta_text()
297
+
298
  elif event_type == "error":
299
  message = payload.get("message", "알 수 없는 오류")
300
  if text_acc:
 
434
 
435
  #answer-wrapper {
436
  min-height: 420px;
437
+ max-height: 42vh;
438
  overflow-y: auto !important;
439
  border: 1px solid var(--ws-border) !important;
440
  border-radius: 14px !important;
 
496
  border: 1px solid #d9e2ec;
497
  border-radius: 6px;
498
  padding: 0.1em 0.35em;
499
+ font-size: 0.92em;answer-wrapper
 
 
 
 
 
500
  background: #0f172a !important;
501
  color: #e2e8f0 !important;
502
  border-radius: 10px;
 
608
 
609
  answer = gr.Markdown(value=to_markdown(""), label="답변", elem_id="answer-wrapper")
610
  timer = gr.Markdown(value=timer_text("0.0초"), elem_id="timer-row")
611
+ meta = gr.Code(label="진행 출력", language="json", elem_id="meta-box")
612
 
613
  gr.HTML(AUTO_SCROLL_SCRIPT, visible=False)
614