Doleeee committed on
Commit
2a0c436
·
1 Parent(s): 3af0512

api_server analyze에서 query로 history 함께 입력가능하도록 수정 #31

Browse files
Files changed (3) hide show
  1. api_server.py +49 -3
  2. llm/generator.py +55 -13
  3. pipeline.py +29 -4
api_server.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
  import threading
4
  from queue import Empty, Queue
5
  from threading import Thread
6
- from typing import Optional
7
 
8
  from fastapi import FastAPI
9
  from fastapi.encoders import jsonable_encoder
@@ -189,11 +189,52 @@ async def create_persona(request: PersonaRequest):
189
  return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
190
 
191
  class QueryRequest(BaseModel):
192
- query: str
193
  stream: bool = True
194
  persona_name: Optional[str] = None
195
 
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  def _sse(payload: dict) -> str:
198
  return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
199
 
@@ -215,11 +256,14 @@ def _build_result_payload(result, stdout: str = "") -> dict:
215
 
216
  @app.post("/analyze/")
217
  async def analyze(request: QueryRequest):
218
- query = (request.query or "").strip()
219
  stream = request.stream
220
 
221
  persona_name = (request.persona_name or "").strip() or None
222
 
 
 
 
223
  if not stream:
224
  stdout_messages = []
225
 
@@ -241,6 +285,7 @@ async def analyze(request: QueryRequest):
241
  try:
242
  result = run_pipeline(
243
  query,
 
244
  persona_name=persona_name,
245
  status_callback=None,
246
  stream_callback=None,
@@ -268,6 +313,7 @@ async def analyze(request: QueryRequest):
268
  try:
269
  result = run_pipeline(
270
  query,
 
271
  persona_name=persona_name,
272
  status_callback=on_status,
273
  stream_callback=on_delta if stream else None,
 
3
  import threading
4
  from queue import Empty, Queue
5
  from threading import Thread
6
+ from typing import List, Optional, Union
7
 
8
  from fastapi import FastAPI
9
  from fastapi.encoders import jsonable_encoder
 
189
  return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
190
 
191
class QueryRequest(BaseModel):
    """Request body for /analyze/.

    `query` is either a plain question string or a full chat history
    (list of ChatMessage) from which the latest user turn is extracted.
    """

    # NOTE(review): "ChatMessage" is a forward reference — the class is defined
    # below this model. Pydantic v2 resolves it lazily at first validation; on
    # Pydantic v1 this would require QueryRequest.update_forward_refs().
    # Confirm the installed pydantic version.
    query: Union[str, List["ChatMessage"]]
    stream: bool = True
    persona_name: Optional[str] = None
195
 
196
 
197
class ChatMessage(BaseModel):
    """One turn of a chat history: who spoke (`role`) and the text (`content`)."""

    # Expected values are "user" / "assistant" / "system", but validation is
    # deferred to the normalization helpers, not enforced here.
    role: str
    content: str
200
+
201
+
202
+ def _normalize_chat_role(role: str) -> str:
203
+ role = (role or "").strip().lower()
204
+ return role
205
+
206
+
207
+ def _normalize_query_input(query_input):
208
+ if isinstance(query_input, str):
209
+ return query_input.strip(), []
210
+
211
+ if not isinstance(query_input, list):
212
+ return "", []
213
+
214
+ conversation = []
215
+ for message in query_input:
216
+ if isinstance(message, ChatMessage):
217
+ role = _normalize_chat_role(message.role)
218
+ content = (message.content or "").strip()
219
+ elif isinstance(message, dict):
220
+ role = _normalize_chat_role(message.get("role", ""))
221
+ content = (message.get("content", "") or "").strip()
222
+ else:
223
+ continue
224
+
225
+ if not role or not content:
226
+ continue
227
+ conversation.append({"role": role, "content": content})
228
+
229
+ current_user_query = ""
230
+ for message in reversed(conversation):
231
+ if message["role"] == "user":
232
+ current_user_query = message["content"]
233
+ break
234
+
235
+ return current_user_query, conversation
236
+
237
+
238
  def _sse(payload: dict) -> str:
239
  return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
240
 
 
256
 
257
  @app.post("/analyze/")
258
  async def analyze(request: QueryRequest):
259
+ query, conversation = _normalize_query_input(request.query)
260
  stream = request.stream
261
 
262
  persona_name = (request.persona_name or "").strip() or None
263
 
264
+ if not query:
265
+ return JSONResponse(status_code=400, content={"error": "query 필드가 비어 있습니다."})
266
+
267
  if not stream:
268
  stdout_messages = []
269
 
 
285
  try:
286
  result = run_pipeline(
287
  query,
288
+ conversation=conversation,
289
  persona_name=persona_name,
290
  status_callback=None,
291
  stream_callback=None,
 
313
  try:
314
  result = run_pipeline(
315
  query,
316
+ conversation=conversation,
317
  persona_name=persona_name,
318
  status_callback=on_status,
319
  stream_callback=on_delta if stream else None,
llm/generator.py CHANGED
@@ -24,6 +24,36 @@ def extract_response_text(resp):
24
  return "\n".join(texts)
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def generate_search_keywords(client, user_query, intent):
28
  """LLM을 통해 구글 뉴스 검색어 리스트 생성"""
29
  language = intent.get("language", "ko")
@@ -110,7 +140,7 @@ def generate_persona(client, user_query):
110
  return None
111
 
112
 
113
- def build_full_prompt(user_query, context, intent, persona=None):
114
  analysis_type = intent.get("analysis_type", "general")
115
  language = intent.get("language", "ko")
116
  system_prompt = SYSTEM_PROMPTS.get(analysis_type, SYSTEM_PROMPTS["general"])
@@ -119,9 +149,6 @@ def build_full_prompt(user_query, context, intent, persona=None):
119
 
120
  if persona:
121
  system_prompt += f"""
122
- [사용자 질의]
123
- {user_query}
124
-
125
  [선택된 페르소나]
126
  이름: {persona.name}
127
  요약: {persona.summary}
@@ -133,26 +160,41 @@ def build_full_prompt(user_query, context, intent, persona=None):
133
  if persona.famous_quotes:
134
  system_prompt += f"\n대표 어록: {' / '.join(persona.famous_quotes)}"
135
 
136
- full_prompt = f"""{system_prompt}
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  [수집된 시장 데이터]
139
- {context}"""
140
- return full_prompt
 
 
141
 
142
- def generate_analysis(client, user_query, context, intent, persona=None):
143
- full_prompt = build_full_prompt(user_query, context, intent, persona)
144
  LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME')
145
  resp = client.responses.create(
146
  model=LLM_MODEL_NAME,
147
- input=full_prompt,
148
  )
149
 
150
  result = extract_response_text(resp)
151
  return result or "(분석 결과를 가져오지 못했습니다)"
152
 
153
- def generate_analysis_stream(client, user_query, context, intent, persona=None):
154
 
155
- full_prompt = build_full_prompt(user_query, context, intent, persona)
156
  llm_model_name = os.environ.get("LLM_MODEL_NAME")
157
  print(f"[⑤] LLM 분석 스트리밍 생성 중 (Responses API, 모델: {llm_model_name})...")
158
 
@@ -166,7 +208,7 @@ def generate_analysis_stream(client, user_query, context, intent, persona=None):
166
  # SDK에 따라 stream API 형태가 다를 수 있어 create(stream=True) 기준으로 처리
167
  stream = client.responses.create(
168
  model=llm_model_name,
169
- input=full_prompt,
170
  stream=True,
171
  )
172
 
 
24
  return "\n".join(texts)
25
 
26
 
27
+ def _normalize_history_role(role):
28
+ role = (role or "").strip().lower()
29
+ if role in {"user", "assistant", "system"}:
30
+ return role
31
+ return None
32
+
33
+
34
+ def _split_conversation_history(conversation, current_user_query):
35
+ if not conversation:
36
+ return []
37
+
38
+ last_user_index = -1
39
+ for i in range(len(conversation) - 1, -1, -1):
40
+ role = (conversation[i].get("role") or "").strip().lower()
41
+ content = (conversation[i].get("content") or "").strip()
42
+ if role == "user" and content == current_user_query:
43
+ last_user_index = i
44
+ break
45
+
46
+ history = conversation[:last_user_index] if last_user_index >= 0 else conversation
47
+ normalized_history = []
48
+ for message in history:
49
+ role = _normalize_history_role(message.get("role"))
50
+ content = (message.get("content") or "").strip()
51
+ if not role or not content:
52
+ continue
53
+ normalized_history.append({"role": role, "content": content})
54
+ return normalized_history
55
+
56
+
57
  def generate_search_keywords(client, user_query, intent):
58
  """LLM을 통해 구글 뉴스 검색어 리스트 생성"""
59
  language = intent.get("language", "ko")
 
140
  return None
141
 
142
 
143
+ def build_system_prompt(intent, persona=None):
144
  analysis_type = intent.get("analysis_type", "general")
145
  language = intent.get("language", "ko")
146
  system_prompt = SYSTEM_PROMPTS.get(analysis_type, SYSTEM_PROMPTS["general"])
 
149
 
150
  if persona:
151
  system_prompt += f"""
 
 
 
152
  [선택된 페르소나]
153
  이름: {persona.name}
154
  요약: {persona.summary}
 
160
  if persona.famous_quotes:
161
  system_prompt += f"\n대표 어록: {' / '.join(persona.famous_quotes)}"
162
 
163
+ return system_prompt
164
+
165
+
166
def build_analysis_input(user_query, context, intent, persona=None, conversation=None):
    """Assemble the Responses API message list for one analysis turn.

    Layout: system prompt first, then prior conversation history (the
    current query is filtered out of it), then a final user message that
    bundles the current query with the collected market data.
    """
    messages = [
        {"role": "system", "content": build_system_prompt(intent, persona=persona)}
    ]
    messages.extend(_split_conversation_history(conversation, user_query))
    messages.append(
        {
            "role": "user",
            "content": f"""[현재 사용자 질의]
{user_query}

[수집된 시장 데이터]
{context}""",
        }
    )
    return messages
183
 
184
def generate_analysis(client, user_query, context, intent, persona=None, conversation=None):
    """Run a single (non-streaming) LLM analysis over the collected context.

    Args:
        client: OpenAI-style client exposing the Responses API.
        user_query: current user question.
        context: collected market data injected into the prompt.
        intent: parsed intent dict (analysis_type, language, ...).
        persona: optional persona object blended into the system prompt.
        conversation: optional prior chat turns included as history.

    Returns:
        The extracted response text, or a Korean fallback notice when empty.
    """
    analysis_input = build_analysis_input(
        user_query, context, intent, persona, conversation=conversation
    )
    # Lowercase local name: PEP 8 reserves UPPER_SNAKE for constants, and this
    # matches generate_analysis_stream's `llm_model_name` convention.
    llm_model_name = os.environ.get("LLM_MODEL_NAME")
    resp = client.responses.create(
        model=llm_model_name,
        input=analysis_input,
    )
    result = extract_response_text(resp)
    return result or "(분석 결과를 가져오지 못했습니다)"
194
 
195
+ def generate_analysis_stream(client, user_query, context, intent, persona=None, conversation=None):
196
 
197
+ analysis_input = build_analysis_input(user_query, context, intent, persona, conversation=conversation)
198
  llm_model_name = os.environ.get("LLM_MODEL_NAME")
199
  print(f"[⑤] LLM 분석 스트리밍 생성 중 (Responses API, 모델: {llm_model_name})...")
200
 
 
208
  # SDK에 따라 stream API 형태가 다를 수 있어 create(stream=True) 기준으로 처리
209
  stream = client.responses.create(
210
  model=llm_model_name,
211
+ input=analysis_input,
212
  stream=True,
213
  )
214
 
pipeline.py CHANGED
@@ -69,7 +69,7 @@ def save_result_jsonl(result):
69
  with open(file_name, "a", encoding="utf-8") as f:
70
  f.write(json.dumps(ordered_data, ensure_ascii=False) + "\n")
71
 
72
- def pipeline(query, persona_name=None, status_callback=None, stream_callback=None, stream=True):
73
  """
74
  파이프라인:
75
  ① 인텐트 파싱 (Chat Completions + Function Calling)
@@ -120,7 +120,14 @@ def pipeline(query, persona_name=None, status_callback=None, stream_callback=Non
120
  if stream:
121
  chunks = []
122
  print("[⑤] 스트리밍 응답 수신 중...")
123
- for delta in generate_analysis_stream(client, query, context, intent, persona=persona):
 
 
 
 
 
 
 
124
  if not delta:
125
  continue
126
  chunks.append(delta)
@@ -130,7 +137,14 @@ def pipeline(query, persona_name=None, status_callback=None, stream_callback=Non
130
  response = "".join(chunks).strip() or "(분석 결과를 가져오지 못했습니다)"
131
  else:
132
  print("[⑤] 단일 응답 생성 중...")
133
- response = generate_analysis(client, query, context, intent, persona=persona)
 
 
 
 
 
 
 
134
  if response:
135
  emit_delta(response)
136
 
@@ -201,13 +215,24 @@ def main():
201
  except ValueError:
202
  print("잘못된 입력입니다. 기본 모드로 진행합니다.")
203
 
 
 
 
 
204
  while True:
205
  text = input("\n질문 > ").strip()
206
  if text.lower() in ("exit", "quit", "종료"):
207
  break
 
 
 
 
208
  if not text:
209
  continue
210
- result = pipeline(text, persona_name=persona_name)
 
 
 
211
  print_result(result)
212
 
213
 
 
69
  with open(file_name, "a", encoding="utf-8") as f:
70
  f.write(json.dumps(ordered_data, ensure_ascii=False) + "\n")
71
 
72
+ def pipeline(query, conversation=None, persona_name=None, status_callback=None, stream_callback=None, stream=True):
73
  """
74
  파이프라인:
75
  ① 인텐트 파싱 (Chat Completions + Function Calling)
 
120
  if stream:
121
  chunks = []
122
  print("[⑤] 스트리밍 응답 수신 중...")
123
+ for delta in generate_analysis_stream(
124
+ client,
125
+ query,
126
+ context,
127
+ intent,
128
+ persona=persona,
129
+ conversation=conversation,
130
+ ):
131
  if not delta:
132
  continue
133
  chunks.append(delta)
 
137
  response = "".join(chunks).strip() or "(분석 결과를 가져오지 못했습니다)"
138
  else:
139
  print("[⑤] 단일 응답 생성 중...")
140
+ response = generate_analysis(
141
+ client,
142
+ query,
143
+ context,
144
+ intent,
145
+ persona=persona,
146
+ conversation=conversation,
147
+ )
148
  if response:
149
  emit_delta(response)
150
 
 
215
  except ValueError:
216
  print("잘못된 입력입니다. 기본 모드로 진행합니다.")
217
 
218
+ conversation = []
219
+ print("\n멀티턴 대화 모드입니다. 이전 질문/답변이 다음 분석에 함께 반영됩니다.")
220
+ print("대화 초기화: reset 또는 clear | 종료: exit, quit, 종료")
221
+
222
  while True:
223
  text = input("\n질문 > ").strip()
224
  if text.lower() in ("exit", "quit", "종료"):
225
  break
226
+ if text.lower() in ("reset", "clear"):
227
+ conversation = []
228
+ print("대화 히스토리를 초기화했습니다.")
229
+ continue
230
  if not text:
231
  continue
232
+
233
+ current_conversation = conversation + [{"role": "user", "content": text}]
234
+ result = pipeline(text, conversation=current_conversation, persona_name=persona_name)
235
+ conversation = current_conversation + [{"role": "assistant", "content": result.llm_response}]
236
  print_result(result)
237
 
238