Spaces:
Sleeping
Sleeping
| # generators/qa.py | |
| from __future__ import annotations | |
| from typing import List, Dict, Tuple | |
| import os, re, textwrap | |
| from irpr.config import settings | |
| from irpr.deps import search as rag_search | |
| # OpenAI 直呼び(finish_reason を見たいので deps.generate_chat は使わない) | |
| def _openai_client(): | |
| from openai import OpenAI | |
| key = os.environ.get("OPENAI_API_KEY", "").strip() | |
| if not key: | |
| raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。") | |
| return OpenAI(api_key=key) | |
| CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL) | |
| # ===== ユーティリティ ===== | |
| def _truncate_chars(s: str, max_chars: int) -> str: | |
| s = (s or "").strip() | |
| if len(s) <= max_chars: | |
| return s | |
| return s[:max_chars].rstrip() + "…" | |
| def _strip_ws(s: str) -> str: | |
| return (s or "").replace("\u3000", " ").strip() | |
| def _dedent(s: str) -> str: | |
| return textwrap.dedent(s).strip() | |
| def _build_sources_block(chunks: List[Dict], per_chunk_max=600) -> Tuple[str, List[str]]: | |
| """ | |
| 参照用に [1] [2] 形式で短く並べる。各チャンクは最大 per_chunk_max 文字に丸める。 | |
| 返り値: (sources_text, links) | |
| """ | |
| lines = [] | |
| links: List[str] = [] | |
| for i, c in enumerate(chunks, 1): | |
| txt = _strip_ws(c.get("text","")) | |
| txt = re.sub(r"\s+", " ", txt) | |
| txt = _truncate_chars(txt, per_chunk_max) | |
| title = _strip_ws(c.get("title") or "") | |
| url = _strip_ws(c.get("source_url") or "") | |
| links.append(url) | |
| if title: | |
| head = f"[{i}] {title}" | |
| else: | |
| head = f"[{i}] 参考 {i}" | |
| if url: | |
| head += f" <{url}>" | |
| lines.append(head + "\n" + txt) | |
| return "\n\n".join(lines), links | |
| def _chat_once(messages: List[Dict], max_tokens=700, temperature=0.2) -> Tuple[str, str]: | |
| client = _openai_client() | |
| resp = client.chat.completions.create( | |
| model=CHAT_MODEL, | |
| messages=messages, | |
| temperature=float(temperature), | |
| max_tokens=int(max_tokens), | |
| ) | |
| choice = resp.choices[0] | |
| content = (choice.message.content or "").strip() | |
| reason = choice.finish_reason or "stop" | |
| return content, reason | |
| def _complete_with_continuations(messages: List[Dict], max_tokens=700, temperature=0.2, | |
| max_rounds=4, hard_cap_chars=8000) -> str: | |
| """ | |
| finish_reason が length の間は「続けてください」を自動で投げて結合。 | |
| 念のため全体文字数に上限(hard_cap_chars)をかける。 | |
| """ | |
| out, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature) | |
| rounds = 1 | |
| while reason == "length" and rounds < max_rounds and len(out) < hard_cap_chars: | |
| messages = messages + [ | |
| {"role":"assistant", "content": out[-1200:]}, # 直前の末尾だけ渡す(文脈維持+トークン節約) | |
| {"role":"user", "content": "続けてください。直前の続きから、重複なく簡潔に出力してください。"} | |
| ] | |
| nxt, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature) | |
| if not nxt: | |
| break | |
| out = (out + "\n" + nxt).strip() | |
| rounds += 1 | |
| return out[:hard_cap_chars].strip() | |
| def _postfix_if_bare(ans: str) -> str: | |
| """ | |
| 末尾が句読点等で終わっていない場合に句点を補う(CSV で切れて見えづらい対策)。 | |
| """ | |
| s = _strip_ws(ans) | |
| if not s: | |
| return s | |
| if re.search(r"[。.!?!?\)]\s*$", s): | |
| return s | |
| # 参照記号 [1] 等で終わる場合はOK | |
| if re.search(r"\[\d+\]\s*$", s): | |
| return s | |
| return s + "。" | |
| # ===== 質問候補の生成 ===== | |
| QUESTION_GUIDE = _dedent(""" | |
| あなたは日本の上場企業のIR資料に詳しいアナリストです。 | |
| 以下のトピックについて、投資家・メディアから想定される質問を日本語で生成してください。 | |
| - 決算ハイライト、通期見通し、セグメント別動向、費用/粗利、キャッシュフロー、投資計画、配当/自社株買い、ESG、リスク、質疑応答で深掘りされやすい論点など | |
| 出力形式は、各行に1問だけのプレーンテキスト。番号は付けません。 | |
| 質問は簡潔に(1行80文字以内)、具体的に、事実確認ではなく説明を引き出す聞き方にしてください。 | |
| """) | |
| def _propose_questions(query: str, n: int) -> List[str]: | |
| base_msgs = [ | |
| {"role":"system", "content": "あなたは有能なIRアナリストです。"}, | |
| {"role":"user", "content": QUESTION_GUIDE + f"\n\n対象トピック:\n{query}\n\n必要な件数: {n}問"} | |
| ] | |
| text = _complete_with_continuations(base_msgs, max_tokens=700, temperature=0.2, max_rounds=2) | |
| qs = [q.strip(" ・-—\t") for q in text.splitlines() if q.strip()] | |
| # 行頭の番号や記号を除去 | |
| cleaned: List[str] = [] | |
| for q in qs: | |
| q = re.sub(r"^\d+[\).、]\s*", "", q) | |
| q = re.sub(r"^[・\-—]\s*", "", q) | |
| if q and q not in cleaned: | |
| cleaned.append(q) | |
| return cleaned[:n] if len(cleaned) >= n else cleaned | |
| # ===== 回答生成 ===== | |
| ANSWER_SYS = _dedent(""" | |
| あなたは上場企業のIR担当者のつもりで、与えられた「資料抜粋(Sources)」だけを根拠に、日本語で誠実に回答します。 | |
| - 回答は5〜10文(または箇条書き5〜8点)で、具体的な数値・期間・要因を入れてください。 | |
| - 資料にない推測はしません。情報が不足していれば「判明分」と「不明点」を分けて述べます。 | |
| - 引用は [1], [2] 形式で付けます(対応するSource番号)。同一文末に複数可。 | |
| - 最後に1文で簡潔に要約してください。 | |
| """).strip() | |
| def _answer_one(question: str, top_k=8) -> Tuple[str, List[str]]: | |
| # 関連チャンク検索 | |
| chunks = rag_search(question, top_k=top_k) | |
| # 取りすぎると長くなるので上位を採用 | |
| chunks = (chunks or [])[:top_k] | |
| sources_text, links = _build_sources_block(chunks, per_chunk_max=500) | |
| # コンテキストが少ない/空なら、ユーザー質問に基づく一般的テンプレで埋める(ただし「一般論」明記) | |
| if not sources_text.strip(): | |
| sources_text = "(該当資料無し)" | |
| # メッセージ組み立て | |
| prompt = _dedent(f""" | |
| <質問> | |
| {question} | |
| <資料抜粋(Sources)> | |
| {sources_text} | |
| <指示> | |
| - 上記 Sources だけを根拠に回答。根拠となった箇所の番号を [n] で明示。 | |
| - 200〜600字程度を目安に、冗長な導入は避け、結論から書く。 | |
| - 数字や固有名詞は元資料に合わせる。 | |
| - 不足があれば不足点を最後に1行で注記。 | |
| """) | |
| messages = [ | |
| {"role":"system", "content": ANSWER_SYS}, | |
| {"role":"user", "content": prompt} | |
| ] | |
| # 途中打ち切りを自動で継続 | |
| ans = _complete_with_continuations(messages, max_tokens=900, temperature=0.2, max_rounds=3, hard_cap_chars=4000) | |
| ans = _postfix_if_bare(ans) | |
| return ans, links | |
| # ===== エクスポートAPI ===== | |
| def make_qa(query: str, n: int = 30) -> Tuple[List[Dict], List[str]]: | |
| """ | |
| 返り値: | |
| qa_list: [{"q": str, "a": str}, ...] | |
| links: 重複排除したURL一覧(参考リンク用) | |
| """ | |
| # まず質問候補を出す | |
| qs = _propose_questions(query, n) | |
| if not qs: | |
| # 最低限のフォールバック | |
| qs = [f"{_strip_ws(query)}の四半期業績の増減要因は?", | |
| "通期見通しの前提(為替、コスト、数量)は?", | |
| "セグメント別の業績動向と主要KPIの見通しは?", | |
| "資本政策(配当方針/自社株買い)とその根拠は?", | |
| "主なリスクと対応策は?"][:n] | |
| qa_list: List[Dict] = [] | |
| all_links: List[str] = [] | |
| for q in qs[:n]: | |
| a, links = _answer_one(q, top_k=8) | |
| qa_list.append({"q": q, "a": a}) | |
| all_links.extend(links or []) | |
| # 重複除去・順序維持 | |
| uniq_links = list(dict.fromkeys([u for u in all_links if u])) | |
| return qa_list, uniq_links | |