Corin1998 committed on
Commit
c1a58c7
·
verified ·
1 Parent(s): 63a5dee

Update generators/qa.py

Browse files
Files changed (1) hide show
  1. generators/qa.py +44 -22
generators/qa.py CHANGED
@@ -1,28 +1,50 @@
1
- import json
2
- from rag.retriever import retrieve, format_citations
3
- from rag.prompts import QA_SYS, QA_USER
4
- from irpr.deps import generate_chat
5
 
6
- def make_qa(query: str, n=30):
7
- hits = retrieve(query, top_k=16)
8
- links = format_citations(hits)
9
- link_lines = "\n".join([f"[{i}] {u}" for i, u in links])
10
- ctx = "\n".join([f"[chunk:{h['chunk_id']}] {h['text']}" for h in hits])
11
 
12
- messages = [
13
- {"role": "system", "content": QA_SYS},
14
- {"role": "user", "content": QA_USER.format(links=link_lines, contexts=ctx)}
15
- ]
16
- raw = generate_chat(messages, max_new_tokens=1800, temperature=0.2)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  try:
19
- data = json.loads(raw)
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  except Exception:
21
- data = []
22
- for line in raw.splitlines():
23
- if "Q:" in line and "A:" in line:
24
- q = line.split("Q:", 1)[1].split("A:", 1)[0].strip()
25
- a = line.split("A:", 1)[1].strip()
26
- data.append({"q": q, "a": a, "sources": []})
27
 
28
- return data[:n], links
 
 
 
 
 
 
 
1
+ # generators/qa.py
2
+ from __future__ import annotations
3
+ from typing import List, Tuple
4
+ from irpr.deps import search, generate_chat
5
 
6
+ SYS = "あなたは日本語のIR担当です。投資家からの想定質問と模範回答を、根拠に基づいて簡潔に作成します。"
 
 
 
 
7
 
8
+ TPL = """次の抜粋を根拠に、投資家向けの想定Q&Aを {n} 問作成してください。
9
+ 各問は "Q: ... / A: ..." の2行で、Aは2-4文以内。根拠があれば括弧で短く示してください。
 
 
 
10
 
11
+ # 抜粋
12
+ {context}
13
+ """
14
+
15
+ def make_qa(query: str, n: int = 30) -> Tuple[List[dict], List[str]]:
16
+ hits = search(query, top_k=min(12, n))
17
+ links, ctx = [], []
18
+ for i, h in enumerate(hits, 1):
19
+ src = h.get("source_url") or ""
20
+ if src and src not in links: links.append(src)
21
+ ctx.append(f"[{i}] {h.get('title') or ''} {src}\n{h['text'][:1000]}")
22
+ context = "\n\n".join(ctx) if ctx else "(根拠なし)"
23
+
24
+ # LLM試行
25
+ qa_list: List[dict] = []
26
  try:
27
+ out = generate_chat(
28
+ [{"role":"system","content":SYS},
29
+ {"role":"user","content":TPL.format(n=n, context=context)}],
30
+ max_new_tokens=1000
31
+ )
32
+ for line in out.splitlines():
33
+ line = line.strip()
34
+ if line.startswith("Q:"):
35
+ qa_list.append({"q": line[2:].strip(), "a": ""})
36
+ elif line.startswith("A:") and qa_list:
37
+ qa_list[-1]["a"] = line[2:].strip()
38
+ qa_list = [x for x in qa_list if x.get("q") and x.get("a")]
39
+ if qa_list:
40
+ return qa_list[:n], links
41
  except Exception:
42
+ pass
 
 
 
 
 
43
 
44
+ # フォールバック(抽出)
45
+ for h in hits[:n]:
46
+ qa_list.append({
47
+ "q": f"{(h.get('title') or '決算トピック')}のポイントは?",
48
+ "a": (h['text'][:240] + "…").replace("\n"," "),
49
+ })
50
+ return qa_list[:n], links