| import json |
| import sys |
| import threading |
| from queue import Empty, Queue |
| from threading import Thread |
| from typing import List, Optional |
|
|
| from fastapi import FastAPI |
| from fastapi.encoders import jsonable_encoder |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from pydantic import BaseModel, Field |
|
|
| from pipeline import pipeline as run_pipeline |
| from persona.make_persona import make_persona |
|
|
# FastAPI application instance; routes are registered on it via decorators below.
app = FastAPI()
|
|
|
|
| class _ThreadStdoutProxy: |
| def __init__(self, target): |
| self._target = target |
| self._handlers = {} |
| self._lock = threading.RLock() |
| self.encoding = getattr(target, "encoding", "utf-8") |
| self.errors = getattr(target, "errors", None) |
|
|
| def register(self, thread_id: int, handler) -> None: |
| with self._lock: |
| self._handlers[thread_id] = handler |
|
|
| def unregister(self, thread_id: int) -> None: |
| with self._lock: |
| self._handlers.pop(thread_id, None) |
|
|
| def _resolve(self): |
| thread_id = threading.get_ident() |
| with self._lock: |
| return self._handlers.get(thread_id), self._target |
|
|
| def write(self, data): |
| handler, target = self._resolve() |
| if handler: |
| return handler.write(data) |
| return target.write(data) |
|
|
| def flush(self): |
| handler, target = self._resolve() |
| if handler: |
| handler.flush() |
| return target.flush() |
|
|
| def isatty(self): |
| return getattr(self._target, "isatty", lambda: False)() |
|
|
| def fileno(self): |
| return self._target.fileno() |
|
|
| def writable(self): |
| return True |
|
|
| def __getattr__(self, name): |
| return getattr(self._target, name) |
|
|
|
|
| class _QueueingStdoutTee: |
| def __init__(self, target, event_queue: Queue): |
| self._target = target |
| self._event_queue = event_queue |
|
|
| def write(self, data): |
| written = self._target.write(data) |
| if data: |
| self._event_queue.put({"type": "stdout", "message": data}) |
| return written |
|
|
| def flush(self): |
| self._target.flush() |
|
|
|
|
# Install the proxy once at import time: worker threads can then tee their
# stdout into per-request SSE queues while every other thread keeps writing
# to the real stdout.  NOTE(review): this mutates global sys.stdout as a
# module import side effect.
_stdout_proxy = _ThreadStdoutProxy(sys.stdout)
sys.stdout = _stdout_proxy
|
|
|
|
class PersonaRequest(BaseModel):
    """Request body for POST /persona/."""

    info: str  # free-form description of the person to build a persona from
    stream: bool = True  # when True, respond with an SSE stream instead of plain JSON
|
|
|
|
# Canned progress messages streamed to the client while a persona is being
# generated.  The last entry is reserved for the final stage: the periodic
# status sender in create_persona slices [:-1] and never emits it.
# Two entries below were split mid-literal by encoding damage in the source
# (a syntax error as written); they are rejoined here.
# TODO(review): one syllable may have been lost in each — confirm the exact
# original wording against version control.
PERSONA_STATUS_MESSAGES = [
    "์ธ๋ฌผ ์ ๋ณด ์์ง ์ค...",
    "์น ๊ฒ์์ ํตํด ๋ฐฐ๊ฒฝ ์กฐ์ฌ ์ค...",
    "๊ธ์ต ์ฌ๊ณ ๋ฐฉ์ ๋ถ์ ์ค...",
    "๋ฐ์ดํฐ ๋ถ์ ์ ๊ทผ๋ฒ ํ๊ฐ ์ค...",
    "๋ต๋ณ ์คํ์ผ ํน์ฑ ํ์ ์ค...",
    "ํต์ฌ ํฌ์ ์์น ์ถ์ถ ์ค...",
    "๋ํ ์ด๋ก ์ ๋ฆฌ ์ค...",
    "ํ๋ฅด์๋ ํ๋กํ ๊ตฌ์ฑ ์ค...",
    "์ต์ข ๊ฒ์ฆ ๋ฐ ์ ์ฅ ์ค๋น ์ค...",
]
|
|
|
|
| def _build_persona_payload(persona) -> dict: |
| return { |
| "type": "result", |
| "name": persona.name, |
| "full_name": persona.full_name, |
| "summary": persona.summary, |
| "financial_mindset": persona.financial_mindset, |
| "data_analysis_approach": persona.data_analysis_approach, |
| "response_style": persona.response_style, |
| "key_principles": persona.key_principles, |
| "famous_quotes": getattr(persona, "famous_quotes", None), |
| } |
|
|
|
|
@app.post("/persona/")
async def create_persona(request: PersonaRequest):
    """Create a persona from free-form info.

    When ``request.stream`` is false the persona is built synchronously and
    returned as one JSON payload.  Otherwise the response is a
    Server-Sent-Events stream of ``status`` / ``stdout`` events, followed by
    a single ``result`` (or ``error``) event and a terminating ``done``
    event.

    Fixes over the previous version: the status sender no longer spins up a
    throwaway asyncio event loop (which was never closed) just to sleep,
    and it is stopped via a threading.Event as soon as generation finishes
    instead of running its full ~64 s schedule per request.
    """
    info = (request.info or "").strip()
    stream = request.stream

    if not info:
        return JSONResponse(status_code=400, content={"error": "info ํ๋๊ฐ ๋น์ด ์์ต๋๋ค."})

    if not stream:
        try:
            persona = make_persona(info)
        except Exception as exc:
            return JSONResponse(status_code=500, content={"error": str(exc)})

        if persona is None:
            return JSONResponse(status_code=500, content={"error": "ํ๋ฅด์๋ ์์ฑ์ ์คํจํ์ต๋๋ค."})

        return JSONResponse(content=persona.model_dump())

    def event_stream():
        event_queue: Queue = Queue()
        # Set by the worker when persona generation finishes so the status
        # sender stops promptly instead of emitting stale messages.
        finished = threading.Event()

        def status_sender():
            # Emit one canned progress message every 8 seconds while the
            # worker is still running.  The final list entry is reserved
            # for the last stage and never sent from here.
            for message in PERSONA_STATUS_MESSAGES[:-1]:
                if finished.is_set():
                    break
                event_queue.put({"type": "status", "message": message})
                # Event.wait doubles as an interruptible sleep: returns
                # True (and we stop) as soon as the worker finishes.
                if finished.wait(8):
                    break

        def worker():
            thread_id = threading.get_ident()
            # Tee this thread's stdout into the SSE queue for the duration
            # of persona generation.
            _stdout_proxy.register(thread_id, _QueueingStdoutTee(_stdout_proxy._target, event_queue))
            try:
                status_thread = Thread(target=status_sender, daemon=True)
                status_thread.start()

                persona = make_persona(info)

                if persona is None:
                    event_queue.put({"type": "error", "message": "ํ๋ฅด์๋ ์์ฑ์ ์คํจํ์ต๋๋ค."})
                else:
                    event_queue.put(_build_persona_payload(persona))
            except Exception as exc:
                event_queue.put({"type": "error", "message": str(exc)})
            finally:
                finished.set()
                _stdout_proxy.unregister(thread_id)
                event_queue.put({"type": "done"})

        yield _sse({"type": "status", "message": "ํ๋ฅด์๋ ์์ฑ ์ค๋น ์ค..."})
        Thread(target=worker, daemon=True).start()

        # Drain queued events until the worker signals completion; the
        # short timeout keeps the generator responsive.
        done = False
        while not done:
            try:
                event = event_queue.get(timeout=0.2)
            except Empty:
                continue
            yield _sse(jsonable_encoder(event))
            if event.get("type") == "done":
                done = True

    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
|
|
class QueryRequest(BaseModel):
    """Request body for POST /analyze/."""

    query: str  # the user's natural-language question
    # Prior chat turns.  NOTE(review): "ChatMessage" is a forward reference —
    # the class is declared below this model; pydantic resolves it lazily,
    # but confirm the model validates before relying on it.
    history: List["ChatMessage"] = Field(default_factory=list)
    stream: bool = True  # when True, respond with an SSE stream instead of plain JSON
    persona_name: Optional[str] = None  # optional persona to answer as
|
|
|
|
class ChatMessage(BaseModel):
    """A single chat turn in the request history."""

    role: str  # speaker role; lower-cased/stripped by _normalize_chat_role
    content: str  # message text; entries with empty content are dropped downstream
|
|
|
|
| def _normalize_chat_role(role: str) -> str: |
| role = (role or "").strip().lower() |
| return role |
|
|
|
|
def _normalize_history_input(history_input):
    """Coerce mixed history entries into clean ``{"role", "content"}`` dicts.

    Accepts ChatMessage models and plain dicts; any other entry type is
    dropped, as are entries whose role or content is empty after
    normalization.  ``None`` input yields an empty list.
    """
    normalized = []
    for entry in history_input or []:
        if isinstance(entry, ChatMessage):
            raw_role, raw_content = entry.role, entry.content
        elif isinstance(entry, dict):
            raw_role, raw_content = entry.get("role", ""), entry.get("content", "")
        else:
            # Unsupported entry type: skip silently.
            continue

        role = _normalize_chat_role(raw_role)
        content = (raw_content or "").strip()
        if role and content:
            normalized.append({"role": role, "content": content})
    return normalized
|
|
|
|
| def _sse(payload: dict) -> str: |
| return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" |
|
|
|
|
| def _build_result_payload(result, stdout: str = "") -> dict: |
| payload = { |
| "type": "result", |
| "query": result.query, |
| "ticker": result.ticker, |
| "analysis_type": result.analysis_type, |
| "data_context": result.data_context, |
| "llm_response": result.llm_response, |
| "timestamp": getattr(result, "timestamp", None), |
| } |
| if stdout: |
| payload["stdout"] = stdout |
| return payload |
|
|
|
|
@app.post("/analyze/")
async def analyze(request: QueryRequest):
    """Run the analysis pipeline for a query.

    Non-streaming requests run the pipeline synchronously on the request
    thread and return a single JSON payload with captured stdout attached.
    Streaming requests return a Server-Sent-Events stream of ``status`` /
    ``stdout`` / ``delta`` events, followed by one ``result`` (or
    ``error``) event and a terminating ``done`` event.
    """
    query = (request.query or "").strip()
    history = _normalize_history_input(request.history)
    stream = request.stream

    persona_name = (request.persona_name or "").strip() or None

    if not query:
        return JSONResponse(status_code=400, content={"error": "query ํ๋๊ฐ ๋น์ด ์์ต๋๋ค."})

    if not stream:
        stdout_messages = []

        # Write-through tee that also collects everything written to stdout
        # into stdout_messages (closed over) so it can be returned with the
        # JSON result.
        class _ListStdoutTee:
            def __init__(self, target):
                self._target = target

            def write(self, data):
                written = self._target.write(data)
                if data:
                    stdout_messages.append(data)
                return written

            def flush(self):
                self._target.flush()

        # Route this thread's stdout through the tee for the duration of
        # the pipeline run; always unregister, even on failure.
        thread_id = threading.get_ident()
        _stdout_proxy.register(thread_id, _ListStdoutTee(_stdout_proxy._target))
        try:
            result = run_pipeline(
                query,
                history=history,
                persona_name=persona_name,
                status_callback=None,
                stream_callback=None,
                stream=False,
            )
        finally:
            _stdout_proxy.unregister(thread_id)
        return JSONResponse(
            content=jsonable_encoder(_build_result_payload(result, stdout="".join(stdout_messages)))
        )

    def event_stream():
        event_queue: Queue = Queue()

        # Pipeline callback: forward status updates to the client.
        def on_status(message: str):
            event_queue.put({"type": "status", "message": message})

        # Pipeline callback: forward incremental LLM output.
        # NOTE(review): `stream` is always true on this path (the
        # non-stream branch returned above), so this guard is a no-op.
        def on_delta(delta: str):
            if stream:
                event_queue.put({"type": "delta", "delta": delta})

        def worker():
            # Runs in a daemon thread: tee stdout into the SSE queue, run
            # the pipeline, and always finish with a "done" event.
            thread_id = threading.get_ident()
            _stdout_proxy.register(thread_id, _QueueingStdoutTee(_stdout_proxy._target, event_queue))
            try:
                result = run_pipeline(
                    query,
                    history=history,
                    persona_name=persona_name,
                    status_callback=on_status,
                    stream_callback=on_delta if stream else None,
                    stream=stream,
                )
                event_queue.put(_build_result_payload(result))
            except Exception as exc:
                event_queue.put({"type": "error", "message": str(exc)})
            finally:
                _stdout_proxy.unregister(thread_id)
                event_queue.put({"type": "done"})

        yield _sse({"type": "status", "message": "์์ฒญ ์์ . ๋ถ์ ์ค๋น ์ค..."})
        Thread(target=worker, daemon=True).start()

        # Drain queued events until the worker signals completion; the
        # short get() timeout keeps the generator loop responsive.
        done = False
        while not done:
            try:
                event = event_queue.get(timeout=0.2)
            except Empty:
                continue
            yield _sse(jsonable_encoder(event))
            if event.get("type") == "done":
                done = True

    # Headers that disable proxy/server buffering so SSE events are
    # delivered as they are produced.
    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
|
|