Corin1998's picture
Update generators/qa.py
2a333a1 verified
# generators/qa.py
from __future__ import annotations
from typing import List, Dict, Tuple
import os, re, textwrap
from irpr.config import settings
from irpr.deps import search as rag_search
# OpenAI 直呼び(finish_reason を見たいので deps.generate_chat は使わない)
def _openai_client():
from openai import OpenAI
key = os.environ.get("OPENAI_API_KEY", "").strip()
if not key:
raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。")
return OpenAI(api_key=key)
CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL)
# ===== ユーティリティ =====
def _truncate_chars(s: str, max_chars: int) -> str:
s = (s or "").strip()
if len(s) <= max_chars:
return s
return s[:max_chars].rstrip() + "…"
def _strip_ws(s: str) -> str:
return (s or "").replace("\u3000", " ").strip()
def _dedent(s: str) -> str:
return textwrap.dedent(s).strip()
def _build_sources_block(chunks: List[Dict], per_chunk_max=600) -> Tuple[str, List[str]]:
"""
参照用に [1] [2] 形式で短く並べる。各チャンクは最大 per_chunk_max 文字に丸める。
返り値: (sources_text, links)
"""
lines = []
links: List[str] = []
for i, c in enumerate(chunks, 1):
txt = _strip_ws(c.get("text",""))
txt = re.sub(r"\s+", " ", txt)
txt = _truncate_chars(txt, per_chunk_max)
title = _strip_ws(c.get("title") or "")
url = _strip_ws(c.get("source_url") or "")
links.append(url)
if title:
head = f"[{i}] {title}"
else:
head = f"[{i}] 参考 {i}"
if url:
head += f" <{url}>"
lines.append(head + "\n" + txt)
return "\n\n".join(lines), links
def _chat_once(messages: List[Dict], max_tokens=700, temperature=0.2) -> Tuple[str, str]:
client = _openai_client()
resp = client.chat.completions.create(
model=CHAT_MODEL,
messages=messages,
temperature=float(temperature),
max_tokens=int(max_tokens),
)
choice = resp.choices[0]
content = (choice.message.content or "").strip()
reason = choice.finish_reason or "stop"
return content, reason
def _complete_with_continuations(messages: List[Dict], max_tokens=700, temperature=0.2,
max_rounds=4, hard_cap_chars=8000) -> str:
"""
finish_reason が length の間は「続けてください」を自動で投げて結合。
念のため全体文字数に上限(hard_cap_chars)をかける。
"""
out, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
rounds = 1
while reason == "length" and rounds < max_rounds and len(out) < hard_cap_chars:
messages = messages + [
{"role":"assistant", "content": out[-1200:]}, # 直前の末尾だけ渡す(文脈維持+トークン節約)
{"role":"user", "content": "続けてください。直前の続きから、重複なく簡潔に出力してください。"}
]
nxt, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
if not nxt:
break
out = (out + "\n" + nxt).strip()
rounds += 1
return out[:hard_cap_chars].strip()
def _postfix_if_bare(ans: str) -> str:
"""
末尾が句読点等で終わっていない場合に句点を補う(CSV で切れて見えづらい対策)。
"""
s = _strip_ws(ans)
if not s:
return s
if re.search(r"[。.!?!?\)]\s*$", s):
return s
# 参照記号 [1] 等で終わる場合はOK
if re.search(r"\[\d+\]\s*$", s):
return s
return s + "。"
# ===== 質問候補の生成 =====
QUESTION_GUIDE = _dedent("""
あなたは日本の上場企業のIR資料に詳しいアナリストです。
以下のトピックについて、投資家・メディアから想定される質問を日本語で生成してください。
- 決算ハイライト、通期見通し、セグメント別動向、費用/粗利、キャッシュフロー、投資計画、配当/自社株買い、ESG、リスク、質疑応答で深掘りされやすい論点など
出力形式は、各行に1問だけのプレーンテキスト。番号は付けません。
質問は簡潔に(1行80文字以内)、具体的に、事実確認ではなく説明を引き出す聞き方にしてください。
""")
def _propose_questions(query: str, n: int) -> List[str]:
base_msgs = [
{"role":"system", "content": "あなたは有能なIRアナリストです。"},
{"role":"user", "content": QUESTION_GUIDE + f"\n\n対象トピック:\n{query}\n\n必要な件数: {n}問"}
]
text = _complete_with_continuations(base_msgs, max_tokens=700, temperature=0.2, max_rounds=2)
qs = [q.strip(" ・-—\t") for q in text.splitlines() if q.strip()]
# 行頭の番号や記号を除去
cleaned: List[str] = []
for q in qs:
q = re.sub(r"^\d+[\).、]\s*", "", q)
q = re.sub(r"^[・\-—]\s*", "", q)
if q and q not in cleaned:
cleaned.append(q)
return cleaned[:n] if len(cleaned) >= n else cleaned
# ===== 回答生成 =====
ANSWER_SYS = _dedent("""
あなたは上場企業のIR担当者のつもりで、与えられた「資料抜粋(Sources)」だけを根拠に、日本語で誠実に回答します。
- 回答は5〜10文(または箇条書き5〜8点)で、具体的な数値・期間・要因を入れてください。
- 資料にない推測はしません。情報が不足していれば「判明分」と「不明点」を分けて述べます。
- 引用は [1], [2] 形式で付けます(対応するSource番号)。同一文末に複数可。
- 最後に1文で簡潔に要約してください。
""").strip()
def _answer_one(question: str, top_k=8) -> Tuple[str, List[str]]:
# 関連チャンク検索
chunks = rag_search(question, top_k=top_k)
# 取りすぎると長くなるので上位を採用
chunks = (chunks or [])[:top_k]
sources_text, links = _build_sources_block(chunks, per_chunk_max=500)
# コンテキストが少ない/空なら、ユーザー質問に基づく一般的テンプレで埋める(ただし「一般論」明記)
if not sources_text.strip():
sources_text = "(該当資料無し)"
# メッセージ組み立て
prompt = _dedent(f"""
<質問>
{question}
<資料抜粋(Sources)>
{sources_text}
<指示>
- 上記 Sources だけを根拠に回答。根拠となった箇所の番号を [n] で明示。
- 200〜600字程度を目安に、冗長な導入は避け、結論から書く。
- 数字や固有名詞は元資料に合わせる。
- 不足があれば不足点を最後に1行で注記。
""")
messages = [
{"role":"system", "content": ANSWER_SYS},
{"role":"user", "content": prompt}
]
# 途中打ち切りを自動で継続
ans = _complete_with_continuations(messages, max_tokens=900, temperature=0.2, max_rounds=3, hard_cap_chars=4000)
ans = _postfix_if_bare(ans)
return ans, links
# ===== エクスポートAPI =====
def make_qa(query: str, n: int = 30) -> Tuple[List[Dict], List[str]]:
"""
返り値:
qa_list: [{"q": str, "a": str}, ...]
links: 重複排除したURL一覧(参考リンク用)
"""
# まず質問候補を出す
qs = _propose_questions(query, n)
if not qs:
# 最低限のフォールバック
qs = [f"{_strip_ws(query)}の四半期業績の増減要因は?",
"通期見通しの前提(為替、コスト、数量)は?",
"セグメント別の業績動向と主要KPIの見通しは?",
"資本政策(配当方針/自社株買い)とその根拠は?",
"主なリスクと対応策は?"][:n]
qa_list: List[Dict] = []
all_links: List[str] = []
for q in qs[:n]:
a, links = _answer_one(q, top_k=8)
qa_list.append({"q": q, "a": a})
all_links.extend(links or [])
# 重複除去・順序維持
uniq_links = list(dict.fromkeys([u for u in all_links if u]))
return qa_list, uniq_links