Corin1998 commited on
Commit
2a333a1
·
verified ·
1 Parent(s): 2dd13f2

Update generators/qa.py

Browse files
Files changed (1) hide show
  1. generators/qa.py +198 -52
generators/qa.py CHANGED
@@ -1,54 +1,200 @@
1
  # generators/qa.py
2
  from __future__ import annotations
3
- from typing import List, Tuple
4
- from irpr.deps import search, generate_chat
5
-
6
- SYS = "あなたは日本語のIR担当です。投資家からの想定質問と模範回答を、根拠に基づいて簡潔に作成します。"
7
-
8
- TPL = """次の抜粋根拠に、投資家向け想定Q&Aを {n} 問作成してくださ
9
- 各問は "Q: ... / A: ..." の2行。Aは2-4文以内。根拠があれば括弧で短く示してください。
10
-
11
- # 抜粋
12
- {context}
13
- """
14
-
15
- def make_qa(query: str, n: int = 30) -> Tuple[List[dict], List[str]]:
16
- hits = search(query, top_k=min(12, n))
17
- links, ctx = [], []
18
- for i, h in enumerate(hits, 1):
19
- src = h.get("source_url") or ""
20
- if src and src not in links: links.append(src)
21
- ctx.append(f"[{i}] {h.get('title') or ''} {src}\n{h['text'][:1000]}")
22
- context = "\n\n".join(ctx) if ctx else "(根拠なし)"
23
-
24
- # OpenAI で生成
25
- qa_list: List[dict] = []
26
- try:
27
- out = generate_chat(
28
- [{"role":"system","content":SYS},
29
- {"role":"user","content":TPL.format(n=n, context=context)}],
30
- max_new_tokens=1500
31
- )
32
- cur = None
33
- for line in out.splitlines():
34
- line = line.strip()
35
- if line.startswith("Q:"):
36
- if cur and cur.get("q") and cur.get("a"):
37
- qa_list.append(cur)
38
- cur = {"q": line[2:].strip(), "a": ""}
39
- elif line.startswith("A:") and cur:
40
- cur["a"] = line[2:].strip()
41
- if cur and cur.get("q") and cur.get("a"):
42
- qa_list.append(cur)
43
- if qa_list:
44
- return qa_list[:n], links
45
- except Exception:
46
- pass
47
-
48
- # フォールバック(抽出)
49
- for h in hits[:n]:
50
- qa_list.append({
51
- "q": f"{(h.get('title') or '決算トピック')}のポイントは?",
52
- "a": (h['text'][:240] + "…").replace("\n"," "),
53
- })
54
- return qa_list[:n], links
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # generators/qa.py
2
  from __future__ import annotations
3
+ from typing import List, Dict, Tuple
4
+ import os, re, textwrap
5
+ from irpr.config import settings
6
+ from irpr.deps import search as rag_search
7
+
8
+ # OpenAI 直呼び(finish_reason 見たい deps.generate_chat は使わな
9
+ def _openai_client():
10
+ from openai import OpenAI
11
+ key = os.environ.get("OPENAI_API_KEY", "").strip()
12
+ if not key:
13
+ raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。")
14
+ return OpenAI(api_key=key)
15
+
16
+ CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL)
17
+
18
+ # ===== ユーティリティ =====
19
+
20
+ def _truncate_chars(s: str, max_chars: int) -> str:
21
+ s = (s or "").strip()
22
+ if len(s) <= max_chars:
23
+ return s
24
+ return s[:max_chars].rstrip() + "…"
25
+
26
+ def _strip_ws(s: str) -> str:
27
+ return (s or "").replace("\u3000", " ").strip()
28
+
29
+ def _dedent(s: str) -> str:
30
+ return textwrap.dedent(s).strip()
31
+
32
+ def _build_sources_block(chunks: List[Dict], per_chunk_max=600) -> Tuple[str, List[str]]:
33
+ """
34
+ 参照用に [1] [2] 形式で短く並べる。各チャンクは最大 per_chunk_max 文字に丸める。
35
+ 返り値: (sources_text, links)
36
+ """
37
+ lines = []
38
+ links: List[str] = []
39
+ for i, c in enumerate(chunks, 1):
40
+ txt = _strip_ws(c.get("text",""))
41
+ txt = re.sub(r"\s+", " ", txt)
42
+ txt = _truncate_chars(txt, per_chunk_max)
43
+ title = _strip_ws(c.get("title") or "")
44
+ url = _strip_ws(c.get("source_url") or "")
45
+ links.append(url)
46
+ if title:
47
+ head = f"[{i}] {title}"
48
+ else:
49
+ head = f"[{i}] 参考 {i}"
50
+ if url:
51
+ head += f" <{url}>"
52
+ lines.append(head + "\n" + txt)
53
+ return "\n\n".join(lines), links
54
+
55
+ def _chat_once(messages: List[Dict], max_tokens=700, temperature=0.2) -> Tuple[str, str]:
56
+ client = _openai_client()
57
+ resp = client.chat.completions.create(
58
+ model=CHAT_MODEL,
59
+ messages=messages,
60
+ temperature=float(temperature),
61
+ max_tokens=int(max_tokens),
62
+ )
63
+ choice = resp.choices[0]
64
+ content = (choice.message.content or "").strip()
65
+ reason = choice.finish_reason or "stop"
66
+ return content, reason
67
+
68
+ def _complete_with_continuations(messages: List[Dict], max_tokens=700, temperature=0.2,
69
+ max_rounds=4, hard_cap_chars=8000) -> str:
70
+ """
71
+ finish_reason が length の間は「続けてください」を自動で投げて結合。
72
+ 念のため全体文字数に上限(hard_cap_chars)をかける。
73
+ """
74
+ out, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
75
+ rounds = 1
76
+ while reason == "length" and rounds < max_rounds and len(out) < hard_cap_chars:
77
+ messages = messages + [
78
+ {"role":"assistant", "content": out[-1200:]}, # 直前の末尾だけ渡す(文脈維持+トークン節約)
79
+ {"role":"user", "content": "続けてください。直前の続きから、重複なく簡潔に出力してください。"}
80
+ ]
81
+ nxt, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
82
+ if not nxt:
83
+ break
84
+ out = (out + "\n" + nxt).strip()
85
+ rounds += 1
86
+ return out[:hard_cap_chars].strip()
87
+
88
+ def _postfix_if_bare(ans: str) -> str:
89
+ """
90
+ 末尾が句読点等で終わっていない場合に句点を補う(CSV で切れて見えづらい対策)。
91
+ """
92
+ s = _strip_ws(ans)
93
+ if not s:
94
+ return s
95
+ if re.search(r"[。.!?!?\)]\s*$", s):
96
+ return s
97
+ # 参照記号 [1] 等で終わる場合はOK
98
+ if re.search(r"\[\d+\]\s*$", s):
99
+ return s
100
+ return s + "。"
101
+
102
+ # ===== 質問候補の生成 =====
103
+
104
+ QUESTION_GUIDE = _dedent("""
105
+ あなたは日本の上場企業のIR資料に詳しいアナリストです。
106
+ 以下のトピックについて、投資家・メディアから想定される質問を日本語で生成してください。
107
+ - 決算ハイライト、通期見通し、セグメント別動向、費用/粗利、キャッシュフロー、投資計画、配当/自社株買い、ESG、リスク、質疑応答で深掘りされやすい論点など
108
+ 出力形式は、各行に1問だけのプレーンテキスト。番号は付けません。
109
+ 質問は簡潔に(1行80文字以内)、具体的に、事実確認ではなく説明を引き出す聞き方にしてください。
110
+ """)
111
+
112
+ def _propose_questions(query: str, n: int) -> List[str]:
113
+ base_msgs = [
114
+ {"role":"system", "content": "あなたは有能なIRアナリストです。"},
115
+ {"role":"user", "content": QUESTION_GUIDE + f"\n\n対象トピック:\n{query}\n\n必要な件数: {n}問"}
116
+ ]
117
+ text = _complete_with_continuations(base_msgs, max_tokens=700, temperature=0.2, max_rounds=2)
118
+ qs = [q.strip(" ・-—\t") for q in text.splitlines() if q.strip()]
119
+ # 行頭の番号や記号を除去
120
+ cleaned: List[str] = []
121
+ for q in qs:
122
+ q = re.sub(r"^\d+[\).、]\s*", "", q)
123
+ q = re.sub(r"^[・\-—]\s*", "", q)
124
+ if q and q not in cleaned:
125
+ cleaned.append(q)
126
+ return cleaned[:n] if len(cleaned) >= n else cleaned
127
+
128
+ # ===== 回答生成 =====
129
+
130
+ ANSWER_SYS = _dedent("""
131
+ あなたは上場企業のIR担当者のつもりで、与えられた「資料抜粋(Sources)」だけを根拠に、日本語で誠実に回答します。
132
+ - 回答は5〜10文(または箇条書き5〜8点)で、具体的な数値・期間・要因を入れてください。
133
+ - 資料にない推測はしません。情報が不足していれば「判明分」と「不明点」を分けて述べます。
134
+ - 引用は [1], [2] 形式で付けます(対応するSource番号)。同一文末に複数可。
135
+ - 最後に1文で簡潔に要約してください。
136
+ """).strip()
137
+
138
+ def _answer_one(question: str, top_k=8) -> Tuple[str, List[str]]:
139
+ # 関連チャンク検索
140
+ chunks = rag_search(question, top_k=top_k)
141
+ # 取りすぎると長くなるので上位を採用
142
+ chunks = (chunks or [])[:top_k]
143
+ sources_text, links = _build_sources_block(chunks, per_chunk_max=500)
144
+
145
+ # コンテキストが少ない/空なら、ユーザー質問に基づく一般的テンプレで埋める(ただし「一般論」明記)
146
+ if not sources_text.strip():
147
+ sources_text = "(該当資料無し)"
148
+
149
+ # メッセージ組み立て
150
+ prompt = _dedent(f"""
151
+ <質問>
152
+ {question}
153
+
154
+ <資料抜粋(Sources)>
155
+ {sources_text}
156
+
157
+ <指示>
158
+ - 上記 Sources だけを根拠に回答。根拠となった箇所の番号を [n] で明示。
159
+ - 200〜600字程度を目安に、冗長な導入は避け、結論から書く。
160
+ - 数字や固有名詞は元資料に合わせる。
161
+ - 不足があれば不足点を最後に1行で注記。
162
+ """)
163
+ messages = [
164
+ {"role":"system", "content": ANSWER_SYS},
165
+ {"role":"user", "content": prompt}
166
+ ]
167
+ # 途中打ち切りを自動で継続
168
+ ans = _complete_with_continuations(messages, max_tokens=900, temperature=0.2, max_rounds=3, hard_cap_chars=4000)
169
+ ans = _postfix_if_bare(ans)
170
+ return ans, links
171
+
172
+ # ===== エクスポートAPI =====
173
+
174
+ def make_qa(query: str, n: int = 30) -> Tuple[List[Dict], List[str]]:
175
+ """
176
+ 返り値:
177
+ qa_list: [{"q": str, "a": str}, ...]
178
+ links: 重複排除したURL一覧(参考リンク用)
179
+ """
180
+ # まず質問候補を出す
181
+ qs = _propose_questions(query, n)
182
+ if not qs:
183
+ # 最低限のフォールバック
184
+ qs = [f"{_strip_ws(query)}の四半期業績の増減要因は?",
185
+ "通期見通しの前提(為替、コスト、数量)は?",
186
+ "セグメント別の業績動向と主要KPIの見通しは?",
187
+ "資本政策(配当方針/自社株買い)とその根拠は?",
188
+ "主なリスクと対応策は?"][:n]
189
+
190
+ qa_list: List[Dict] = []
191
+ all_links: List[str] = []
192
+ for q in qs[:n]:
193
+ a, links = _answer_one(q, top_k=8)
194
+ qa_list.append({"q": q, "a": a})
195
+ all_links.extend(links or [])
196
+
197
+ # 重複除去・順序維持
198
+ uniq_links = list(dict.fromkeys([u for u in all_links if u]))
199
+
200
+ return qa_list, uniq_links