File size: 8,349 Bytes
c1a58c7
 
2a333a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# generators/qa.py
from __future__ import annotations
from typing import List, Dict, Tuple
import os, re, textwrap
from irpr.config import settings
from irpr.deps import search as rag_search

# OpenAI 直呼び(finish_reason を見たいので deps.generate_chat は使わない)
def _openai_client():
    from openai import OpenAI
    key = os.environ.get("OPENAI_API_KEY", "").strip()
    if not key:
        raise RuntimeError("OPENAI_API_KEY が未設定です。環境変数に設定してください。")
    return OpenAI(api_key=key)

CHAT_MODEL = os.environ.get("OPENAI_CHAT_MODEL", settings.OPENAI_CHAT_MODEL)

# ===== ユーティリティ =====

def _truncate_chars(s: str, max_chars: int) -> str:
    s = (s or "").strip()
    if len(s) <= max_chars:
        return s
    return s[:max_chars].rstrip() + "…"

def _strip_ws(s: str) -> str:
    return (s or "").replace("\u3000", " ").strip()

def _dedent(s: str) -> str:
    return textwrap.dedent(s).strip()

def _build_sources_block(chunks: List[Dict], per_chunk_max=600) -> Tuple[str, List[str]]:
    """
    参照用に [1] [2] 形式で短く並べる。各チャンクは最大 per_chunk_max 文字に丸める。
    返り値: (sources_text, links)
    """
    lines = []
    links: List[str] = []
    for i, c in enumerate(chunks, 1):
        txt = _strip_ws(c.get("text",""))
        txt = re.sub(r"\s+", " ", txt)
        txt = _truncate_chars(txt, per_chunk_max)
        title = _strip_ws(c.get("title") or "")
        url = _strip_ws(c.get("source_url") or "")
        links.append(url)
        if title:
            head = f"[{i}] {title}"
        else:
            head = f"[{i}] 参考 {i}"
        if url:
            head += f" <{url}>"
        lines.append(head + "\n" + txt)
    return "\n\n".join(lines), links

def _chat_once(messages: List[Dict], max_tokens=700, temperature=0.2) -> Tuple[str, str]:
    client = _openai_client()
    resp = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=messages,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
    )
    choice = resp.choices[0]
    content = (choice.message.content or "").strip()
    reason  = choice.finish_reason or "stop"
    return content, reason

def _complete_with_continuations(messages: List[Dict], max_tokens=700, temperature=0.2,
                                 max_rounds=4, hard_cap_chars=8000) -> str:
    """
    finish_reason が length の間は「続けてください」を自動で投げて結合。
    念のため全体文字数に上限(hard_cap_chars)をかける。
    """
    out, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
    rounds = 1
    while reason == "length" and rounds < max_rounds and len(out) < hard_cap_chars:
        messages = messages + [
            {"role":"assistant", "content": out[-1200:]},  # 直前の末尾だけ渡す(文脈維持+トークン節約)
            {"role":"user", "content": "続けてください。直前の続きから、重複なく簡潔に出力してください。"}
        ]
        nxt, reason = _chat_once(messages, max_tokens=max_tokens, temperature=temperature)
        if not nxt:
            break
        out = (out + "\n" + nxt).strip()
        rounds += 1
    return out[:hard_cap_chars].strip()

def _postfix_if_bare(ans: str) -> str:
    """
    末尾が句読点等で終わっていない場合に句点を補う(CSV で切れて見えづらい対策)。
    """
    s = _strip_ws(ans)
    if not s:
        return s
    if re.search(r"[。.!?!?\)]\s*$", s):
        return s
    # 参照記号 [1] 等で終わる場合はOK
    if re.search(r"\[\d+\]\s*$", s):
        return s
    return s + "。"

# ===== 質問候補の生成 =====

QUESTION_GUIDE = _dedent("""
あなたは日本の上場企業のIR資料に詳しいアナリストです。
以下のトピックについて、投資家・メディアから想定される質問を日本語で生成してください。
- 決算ハイライト、通期見通し、セグメント別動向、費用/粗利、キャッシュフロー、投資計画、配当/自社株買い、ESG、リスク、質疑応答で深掘りされやすい論点など
出力形式は、各行に1問だけのプレーンテキスト。番号は付けません。
質問は簡潔に(1行80文字以内)、具体的に、事実確認ではなく説明を引き出す聞き方にしてください。
""")

def _propose_questions(query: str, n: int) -> List[str]:
    base_msgs = [
        {"role":"system", "content": "あなたは有能なIRアナリストです。"},
        {"role":"user", "content": QUESTION_GUIDE + f"\n\n対象トピック:\n{query}\n\n必要な件数: {n}問"}
    ]
    text = _complete_with_continuations(base_msgs, max_tokens=700, temperature=0.2, max_rounds=2)
    qs = [q.strip(" ・-—\t") for q in text.splitlines() if q.strip()]
    # 行頭の番号や記号を除去
    cleaned: List[str] = []
    for q in qs:
        q = re.sub(r"^\d+[\).、]\s*", "", q)
        q = re.sub(r"^[・\-—]\s*", "", q)
        if q and q not in cleaned:
            cleaned.append(q)
    return cleaned[:n] if len(cleaned) >= n else cleaned

# ===== 回答生成 =====

ANSWER_SYS = _dedent("""
あなたは上場企業のIR担当者のつもりで、与えられた「資料抜粋(Sources)」だけを根拠に、日本語で誠実に回答します。
- 回答は5〜10文(または箇条書き5〜8点)で、具体的な数値・期間・要因を入れてください。
- 資料にない推測はしません。情報が不足していれば「判明分」と「不明点」を分けて述べます。
- 引用は [1], [2] 形式で付けます(対応するSource番号)。同一文末に複数可。
- 最後に1文で簡潔に要約してください。
""").strip()

def _answer_one(question: str, top_k=8) -> Tuple[str, List[str]]:
    # 関連チャンク検索
    chunks = rag_search(question, top_k=top_k)
    # 取りすぎると長くなるので上位を採用
    chunks = (chunks or [])[:top_k]
    sources_text, links = _build_sources_block(chunks, per_chunk_max=500)

    # コンテキストが少ない/空なら、ユーザー質問に基づく一般的テンプレで埋める(ただし「一般論」明記)
    if not sources_text.strip():
        sources_text = "(該当資料無し)"

    # メッセージ組み立て
    prompt = _dedent(f"""
    <質問>
    {question}

    <資料抜粋(Sources)>
    {sources_text}

    <指示>
    - 上記 Sources だけを根拠に回答。根拠となった箇所の番号を [n] で明示。
    - 200〜600字程度を目安に、冗長な導入は避け、結論から書く。
    - 数字や固有名詞は元資料に合わせる。
    - 不足があれば不足点を最後に1行で注記。
    """)
    messages = [
        {"role":"system", "content": ANSWER_SYS},
        {"role":"user", "content": prompt}
    ]
    # 途中打ち切りを自動で継続
    ans = _complete_with_continuations(messages, max_tokens=900, temperature=0.2, max_rounds=3, hard_cap_chars=4000)
    ans = _postfix_if_bare(ans)
    return ans, links

# ===== エクスポートAPI =====

def make_qa(query: str, n: int = 30) -> Tuple[List[Dict], List[str]]:
    """
    返り値:
      qa_list: [{"q": str, "a": str}, ...]
      links:   重複排除したURL一覧(参考リンク用)
    """
    # まず質問候補を出す
    qs = _propose_questions(query, n)
    if not qs:
        # 最低限のフォールバック
        qs = [f"{_strip_ws(query)}の四半期業績の増減要因は?",
              "通期見通しの前提(為替、コスト、数量)は?",
              "セグメント別の業績動向と主要KPIの見通しは?",
              "資本政策(配当方針/自社株買い)とその根拠は?",
              "主なリスクと対応策は?"][:n]

    qa_list: List[Dict] = []
    all_links: List[str] = []
    for q in qs[:n]:
        a, links = _answer_one(q, top_k=8)
        qa_list.append({"q": q, "a": a})
        all_links.extend(links or [])

    # 重複除去・順序維持
    uniq_links = list(dict.fromkeys([u for u in all_links if u]))

    return qa_list, uniq_links