Spaces:
Sleeping
Sleeping
File size: 8,349 Bytes
c1a58c7 2a333a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# generators/qa.py
from __future__ import annotations
from typing import List, Dict, Tuple
import os, re, textwrap
from irpr.config import settings
from irpr.deps import search as rag_search
# OpenAI 直呼び(finish_reason を見たいので deps.generate_chat は使わない)
def _openai_client():
from openai import OpenAI
key = os.environ.get("OPENAI_API_KEY", "").strip()
if not key:
raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。")
return OpenAI(api_key=key)
CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL)
# ===== ユーティリティ =====
def _truncate_chars(s: str, max_chars: int) -> str:
s = (s or "").strip()
if len(s) <= max_chars:
return s
return s[:max_chars].rstrip() + "…"
def _strip_ws(s: str) -> str:
return (s or "").replace("\u3000", " ").strip()
def _dedent(s: str) -> str:
return textwrap.dedent(s).strip()
def _build_sources_block(chunks: List[Dict], per_chunk_max=600) -> Tuple[str, List[str]]:
"""
参照用に [1] [2] 形式で短く並べる。各チャンクは最大 per_chunk_max 文字に丸める。
返り値: (sources_text, links)
"""
lines = []
links: List[str] = []
for i, c in enumerate(chunks, 1):
txt = _strip_ws(c.get("text",""))
txt = re.sub(r"\s+", " ", txt)
txt = _truncate_chars(txt, per_chunk_max)
title = _strip_ws(c.get("title") or "")
url = _strip_ws(c.get("source_url") or "")
links.append(url)
if title:
head = f"[{i}] {title}"
else:
head = f"[{i}] 参考 {i}"
if url:
head += f" <{url}>"
lines.append(head + "\n" + txt)
return "\n\n".join(lines), links
def _chat_once(messages: List[Dict], max_tokens=700, temperature=0.2) -> Tuple[str, str]:
client = _openai_client()
resp = client.chat.completions.create(
model=CHAT_MODEL,
messages=messages,
temperature=float(temperature),
max_tokens=int(max_tokens),
)
choice = resp.choices[0]
content = (choice.message.content or "").strip()
reason = choice.finish_reason or "stop"
return content, reason
def _complete_with_continuations(messages: List[Dict], max_tokens=700, temperature=0.2,
max_rounds=4, hard_cap_chars=8000) -> str:
"""
finish_reason が length の間は「続けてください」を自動で投げて結合。
念のため全体文字数に上限(hard_cap_chars)をかける。
"""
out, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
rounds = 1
while reason == "length" and rounds < max_rounds and len(out) < hard_cap_chars:
messages = messages + [
{"role":"assistant", "content": out[-1200:]}, # 直前の末尾だけ渡す(文脈維持+トークン節約)
{"role":"user", "content": "続けてください。直前の続きから、重複なく簡潔に出力してください。"}
]
nxt, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
if not nxt:
break
out = (out + "\n" + nxt).strip()
rounds += 1
return out[:hard_cap_chars].strip()
def _postfix_if_bare(ans: str) -> str:
"""
末尾が句読点等で終わっていない場合に句点を補う(CSV で切れて見えづらい対策)。
"""
s = _strip_ws(ans)
if not s:
return s
if re.search(r"[。.!?!?\)]\s*$", s):
return s
# 参照記号 [1] 等で終わる場合はOK
if re.search(r"\[\d+\]\s*$", s):
return s
return s + "。"
# ===== 質問候補の生成 =====
QUESTION_GUIDE = _dedent("""
あなたは日本の上場企業のIR資料に詳しいアナリストです。
以下のトピックについて、投資家・メディアから想定される質問を日本語で生成してください。
- 決算ハイライト、通期見通し、セグメント別動向、費用/粗利、キャッシュフロー、投資計画、配当/自社株買い、ESG、リスク、質疑応答で深掘りされやすい論点など
出力形式は、各行に1問だけのプレーンテキスト。番号は付けません。
質問は簡潔に(1行80文字以内)、具体的に、事実確認ではなく説明を引き出す聞き方にしてください。
""")
def _propose_questions(query: str, n: int) -> List[str]:
base_msgs = [
{"role":"system", "content": "あなたは有能なIRアナリストです。"},
{"role":"user", "content": QUESTION_GUIDE + f"\n\n対象トピック:\n{query}\n\n必要な件数: {n}問"}
]
text = _complete_with_continuations(base_msgs, max_tokens=700, temperature=0.2, max_rounds=2)
qs = [q.strip(" ・-—\t") for q in text.splitlines() if q.strip()]
# 行頭の番号や記号を除去
cleaned: List[str] = []
for q in qs:
q = re.sub(r"^\d+[\).、]\s*", "", q)
q = re.sub(r"^[・\-—]\s*", "", q)
if q and q not in cleaned:
cleaned.append(q)
return cleaned[:n] if len(cleaned) >= n else cleaned
# ===== 回答生成 =====
ANSWER_SYS = _dedent("""
あなたは上場企業のIR担当者のつもりで、与えられた「資料抜粋(Sources)」だけを根拠に、日本語で誠実に回答します。
- 回答は5〜10文(または箇条書き5〜8点)で、具体的な数値・期間・要因を入れてください。
- 資料にない推測はしません。情報が不足していれば「判明分」と「不明点」を分けて述べます。
- 引用は [1], [2] 形式で付けます(対応するSource番号)。同一文末に複数可。
- 最後に1文で簡潔に要約してください。
""").strip()
def _answer_one(question: str, top_k=8) -> Tuple[str, List[str]]:
# 関連チャンク検索
chunks = rag_search(question, top_k=top_k)
# 取りすぎると長くなるので上位を採用
chunks = (chunks or [])[:top_k]
sources_text, links = _build_sources_block(chunks, per_chunk_max=500)
# コンテキストが少ない/空なら、ユーザー質問に基づく一般的テンプレで埋める(ただし「一般論」明記)
if not sources_text.strip():
sources_text = "(該当資料無し)"
# メッセージ組み立て
prompt = _dedent(f"""
<質問>
{question}
<資料抜粋(Sources)>
{sources_text}
<指示>
- 上記 Sources だけを根拠に回答。根拠となった箇所の番号を [n] で明示。
- 200〜600字程度を目安に、冗長な導入は避け、結論から書く。
- 数字や固有名詞は元資料に合わせる。
- 不足があれば不足点を最後に1行で注記。
""")
messages = [
{"role":"system", "content": ANSWER_SYS},
{"role":"user", "content": prompt}
]
# 途中打ち切りを自動で継続
ans = _complete_with_continuations(messages, max_tokens=900, temperature=0.2, max_rounds=3, hard_cap_chars=4000)
ans = _postfix_if_bare(ans)
return ans, links
# ===== エクスポートAPI =====
def make_qa(query: str, n: int = 30) -> Tuple[List[Dict], List[str]]:
"""
返り値:
qa_list: [{"q": str, "a": str}, ...]
links: 重複排除したURL一覧(参考リンク用)
"""
# まず質問候補を出す
qs = _propose_questions(query, n)
if not qs:
# 最低限のフォールバック
qs = [f"{_strip_ws(query)}の四半期業績の増減要因は?",
"通期見通しの前提(為替、コスト、数量)は?",
"セグメント別の業績動向と主要KPIの見通しは?",
"資本政策(配当方針/自社株買い)とその根拠は?",
"主なリスクと対応策は?"][:n]
qa_list: List[Dict] = []
all_links: List[str] = []
for q in qs[:n]:
a, links = _answer_one(q, top_k=8)
qa_list.append({"q": q, "a": a})
all_links.extend(links or [])
# 重複除去・順序維持
uniq_links = list(dict.fromkeys([u for u in all_links if u]))
return qa_list, uniq_links
|