from __future__ import annotations
from typing import List, Dict, Any
import os
import re
import json
from huggingface_hub import InferenceClient
from transformers import pipeline
_CLIENT = None
_PIPE = None


def _client():
    """Return a cached InferenceClient when HF_TOKEN is set, otherwise None."""
    global _CLIENT
    tok = os.getenv("HF_TOKEN")
    model_id = os.getenv("GEN_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3")
    if tok:
        if _CLIENT is None:
            _CLIENT = InferenceClient(model=model_id, token=tok, timeout=120)
        return _CLIENT
    return None


def _local_pipe():
    global _PIPE
    if _PIPE is None:
        # Lightweight local fallback. JSON fidelity is lower than the hosted model,
        # so the downstream parsing stays strict.
        model_id = os.getenv("LOCAL_MODEL_ID", "google/flan-t5-base")
        _PIPE = pipeline("text2text-generation", model=model_id)
    return _PIPE


# System prompt (kept in Japanese because the app answers in Japanese): act as a travel
# planner, output only a concise, practical JSON summary, do not repeat sentences, and
# do not dump long quotes, arrays, token sequences, or raw tables from the sources.
SYSTEM_PROMPT = (
    "あなたは旅行プランナーです。入力の条件と要約用ソースから、"
    "日本語で簡潔かつ実用的な要約をJSONのみで出力します。"
    "同じ文やフレーズを繰り返さないでください。"
    "ソース本文の長い引用や配列・トークン列・表の生出力は禁止です。"
)


def _call_chat(prompt: str, max_new_tokens: int = 512) -> str:
    c = _client()
    if c:
        msg = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        out = c.chat_completion(messages=msg, max_tokens=max_new_tokens, temperature=0.4)
        return out.choices[0].message["content"]
    else:
        # Local fallback: flan-t5 has no chat template, so prepend the system prompt
        # to the user prompt instead of sending separate roles.
        pipe = _local_pipe()
        out = pipe(SYSTEM_PROMPT + "\n\n" + prompt, max_new_tokens=max_new_tokens, do_sample=False)
        return out[0]["generated_text"]
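
# Illustrative note (a sketch, not part of the app flow): with HF_TOKEN unset, the local
# text2text pipeline answers instead of the hosted chat model, e.g.
#   raw = _call_chat("京都の半日プランをJSONで", max_new_tokens=128)
# and generate_structured_summary() below still has to coerce that raw text into JSON.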


def _json_sentinel_strip(text: str) -> str:
    # Keep only the span from the first "{" to the last "}" so that JSON parsing can be
    # attempted even when the model adds a preamble or trailing text.
    if not text:
        return "{}"
    i = text.find("{")
    j = text.rfind("}")
    if i == -1 or j == -1 or j <= i:
        return "{}"
    return text[i:j + 1]
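
# Expected behaviour of the strip, shown on made-up strings:
#   _json_sentinel_strip('前置き {"title": "嵐山半日"} 以上です')  -> '{"title": "嵐山半日"}'
#   _json_sentinel_strip("no braces here")                          -> "{}"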


def generate_structured_summary(query: str, passages: List[dict], weather: dict, prefs: dict) -> Dict[str, Any]:
    """
    Expected JSON:
    {
      "title": "short title",
      "highlights": ["3-6 bullet points"],
      "memos": ["3-6 practical notes (reservations / crowds / rainy-day alternatives / costs, etc.)"]
    }
    """
    ctx = "\n\n".join(
        [f"[Source {i+1}] {p.get('meta', {}).get('title') or p.get('meta', {}).get('type')} — {p.get('text', '')[:300]}"
         for i, p in enumerate(passages)]
    )
    prompt = f"""
出力は**必ず**次のJSONのみ。説明文や前置きは一切禁止。
{{
"title": "(10〜25文字程度の日本語)",
"highlights": ["3〜6件。短い日本語。重複禁止。"],
"memos": ["3〜6件。短い日本語。重複禁止。"]
}}
ユーザー条件: {json.dumps(prefs, ensure_ascii=False)}
天候: {json.dumps(weather, ensure_ascii=False)}
クエリ: {query}
ソース(要点のみ反映。本文のコピペ/配列/座標/トークン列の出力は**厳禁**):
{ctx}
"""
    raw = _call_chat(prompt, max_new_tokens=600)
    js = _json_sentinel_strip(raw)
    try:
        data = json.loads(js)
    except Exception:
        # Fallback skeleton when the model output is not valid JSON.
        data = {"title": "観光プランの要約", "highlights": [], "memos": []}
    if not isinstance(data, dict):
        data = {"title": "観光プランの要約", "highlights": [], "memos": []}

    # Basic cleanup and de-duplication.
    def _uniq(seq):
        seen = set()
        out = []
        for s in (seq or []):
            s = (s or "").strip()
            if not s or s in seen:
                continue
            seen.add(s)
            out.append(s)
        return out[:6]

    data["title"] = (data.get("title") or "観光プランの要約").strip()
    data["highlights"] = _uniq(data.get("highlights", []))
    data["memos"] = _uniq(data.get("memos", []))
    return data
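
# Illustrative call with made-up inputs (field names such as "budget" and "pace" are
# assumptions; the real passages/weather/prefs come from the retriever and the app UI):
#   summary = generate_structured_summary(
#       query="雨の日の京都を半日で回りたい",
#       passages=[{"text": "東寺の五重塔は...", "meta": {"title": "東寺", "type": "spot"}}],
#       weather={"condition": "rain", "temp_c": 14},
#       prefs={"budget": "low", "pace": "slow"},
#   )
#   # -> {"title": "...", "highlights": ["...", ...], "memos": ["...", ...]}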


# --- Post-processing (light Japanese text cleanup / de-duplication) ---
_SENT_SPLIT = re.compile(r"(?<=[。.!!??])\s*")


def postprocess_text_ja(text: str) -> str:
    if not text:
        return ""
    t = re.sub(r"[ \t]+", " ", text).strip()
    # Sentence-level de-duplication.
    sents = [s.strip() for s in _SENT_SPLIT.split(t) if s.strip()]
    seen_s, out_s = set(), []
    for s in sents:
        # Drop noisy sentences that are mostly symbols.
        letters = len(re.findall(r"[A-Za-z0-9一-龥ぁ-んァ-ヶー]", s))
        if letters == 0 or letters / max(1, len(s)) < 0.4:
            continue
        if s in seen_s:
            continue
        seen_s.add(s)
        out_s.append(s)
    t2 = "".join(out_s)
    # Line-level de-duplication and collapsing of consecutive blank lines.
    lines = [l.rstrip() for l in t2.splitlines()]
    seen_l, out_l = set(), []
    for l in lines:
        if l.strip() == "":
            if out_l and out_l[-1] == "":
                continue
            out_l.append("")
            continue
        if l in seen_l:
            continue
        seen_l.add(l)
        out_l.append(l)
    return "\n".join(out_l).strip()
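

# Minimal smoke test (illustrative only; in the app this module is imported by the
# RAG/UI layer rather than run directly):
if __name__ == "__main__":
    sample = "清水寺は混雑します。清水寺は混雑します。\n\n\n◆◆◆---◆◆◆"
    # The duplicate sentence and the symbol-only line should both be removed.
    print(postprocess_text_ja(sample))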