| from __future__ import annotations |
| from typing import List, Dict, Any |
| import os |
| import re |
| import json |
| from huggingface_hub import InferenceClient |
| from transformers import pipeline |
|
|
# Lazily-initialized module singletons (populated by _client() / _local_pipe()).
_CLIENT = None  # hosted InferenceClient; created on first use when HF_TOKEN is set
_PIPE = None  # local transformers pipeline; fallback when no HF token is available
|
|
| def _client(): |
| global _CLIENT |
| tok = os.getenv("HF_TOKEN") |
| model_id = os.getenv("GEN_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3") |
| if tok: |
| if _CLIENT is None: |
| _CLIENT = InferenceClient(model=model_id, token=tok, timeout=120) |
| return _CLIENT |
| return None |
|
|
def _local_pipe():
    """Return a memoized local text2text-generation pipeline (fallback path).

    Model id comes from LOCAL_MODEL_ID (default: google/flan-t5-base).
    """
    global _PIPE
    if _PIPE is not None:
        return _PIPE
    model_id = os.getenv("LOCAL_MODEL_ID", "google/flan-t5-base")
    _PIPE = pipeline("text2text-generation", model=model_id)
    return _PIPE
|
|
# System prompt (Japanese, sent to the model verbatim): act as a travel planner,
# output concise practical summaries as JSON only, never repeat sentences or
# phrases, and never dump raw quotes/arrays/token streams/tables from sources.
SYSTEM_PROMPT = (
    "あなたは旅行プランナーです。入力の条件と要約用ソースから、"
    "日本語で簡潔かつ実用的な要約をJSONのみで出力します。"
    "同じ文やフレーズを繰り返さないでください。"
    "ソース本文の長い引用や配列・トークン列・表の生出力は禁止です。"
)
|
|
def _call_chat(prompt: str, max_new_tokens: int = 512) -> str:
    """Generate text for *prompt*, preferring the hosted chat API.

    When an InferenceClient is available, chat with SYSTEM_PROMPT as the
    system message; otherwise run the local text2text pipeline greedily.
    """
    client = _client()
    if client is None:
        result = _local_pipe()(prompt, max_new_tokens=max_new_tokens, do_sample=False)
        return result[0]["generated_text"]
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    completion = client.chat_completion(
        messages=messages, max_tokens=max_new_tokens, temperature=0.4
    )
    return completion.choices[0].message["content"]
|
|
| def _json_sentinel_strip(text: str) -> str: |
| |
| if not text: |
| return "{}" |
| i = text.find("{") |
| j = text.rfind("}") |
| if i == -1 or j == -1 or j <= i: |
| return "{}" |
| return text[i:j+1] |
|
|
def generate_structured_summary(query: str, passages: List[dict], weather: dict, prefs: dict) -> Dict[str, Any]:
    """Build a structured Japanese trip summary from retrieved passages.

    Expected JSON shape:
    {
        "title": "short title",
        "highlights": ["3-6 bullet points"],
        "memos": ["3-6 practical notes (booking/crowds/rain plans/costs)"]
    }

    Falls back to a minimal default when the model output is not valid JSON,
    and coerces/deduplicates fields so a malformed model response (non-dict
    JSON, non-string title, non-string list items) never raises.
    """
    # Each passage contributes a one-line source header plus a 300-char excerpt.
    ctx = "\n\n".join(
        f"[Source {i+1}] {p.get('meta', {}).get('title') or p.get('meta', {}).get('type')} — {p.get('text', '')[:300]}"
        for i, p in enumerate(passages)
    )
    # Prompt body is user-facing Japanese; kept verbatim.
    prompt = f"""
出力は**必ず**次のJSONのみ。説明文や前置きは一切禁止。
{{
"title": "(10〜25文字程度の日本語)",
"highlights": ["3〜6件。短い日本語。重複禁止。"],
"memos": ["3〜6件。短い日本語。重複禁止。"]
}}

ユーザー条件: {json.dumps(prefs, ensure_ascii=False)}
天候: {json.dumps(weather, ensure_ascii=False)}
クエリ: {query}

ソース(要点のみ反映。本文のコピペ/配列/座標/トークン列の出力は**厳禁**):
{ctx}
"""
    raw = _call_chat(prompt, max_new_tokens=600)
    try:
        data = json.loads(_json_sentinel_strip(raw))
    except Exception:
        # Unparseable model output: start empty; defaults are filled below.
        data = {}
    if not isinstance(data, dict):
        # A bare JSON scalar/array is useless here; normalize to a dict.
        data = {}

    def _uniq(seq):
        # Coerce to str, strip, drop empties/duplicates (order-preserving), cap at 6.
        seen, out = set(), []
        for item in (seq or []):
            s = str(item).strip() if item is not None else ""
            if s and s not in seen:
                seen.add(s)
                out.append(s)
        return out[:6]

    title = data.get("title")
    title = str(title).strip() if title is not None else ""
    data["title"] = title or "観光プランの要約"
    data["highlights"] = _uniq(data.get("highlights", []))
    data["memos"] = _uniq(data.get("memos", []))
    return data
|
|
| |
|
|
| _SENT_SPLIT = re.compile(r"(?<=[。.!!??])\s*") |
|
|
def postprocess_text_ja(text: str) -> str:
    """Clean up generated Japanese text.

    Collapses runs of spaces/tabs, drops sentences that are duplicates or
    mostly non-letter noise (< 40% CJK/alphanumeric characters), then removes
    duplicate lines and squeezes consecutive blank lines down to one.
    """
    if not text:
        return ""
    collapsed = re.sub(r"[ \t]+", " ", text).strip()

    # Sentence pass: keep the first occurrence of each "real" sentence.
    sent_split = re.compile(r"(?<=[。.!!??])\s*")  # split after sentence-ending punctuation
    letter_re = re.compile(r"[A-Za-z0-9一-龥ぁ-んァ-ヶー]")
    kept_sentences, seen_sentences = [], set()
    for sent in (chunk.strip() for chunk in sent_split.split(collapsed)):
        if not sent or sent in seen_sentences:
            continue
        n_letters = len(letter_re.findall(sent))
        # Discard symbol-heavy noise (tables, token dumps, coordinates, …).
        if n_letters == 0 or n_letters / max(1, len(sent)) < 0.4:
            continue
        seen_sentences.add(sent)
        kept_sentences.append(sent)
    rejoined = "".join(kept_sentences)

    # Line pass: dedupe non-blank lines; allow at most one blank in a row.
    result_lines, seen_lines = [], set()
    for line in (raw.rstrip() for raw in rejoined.splitlines()):
        if not line.strip():
            if not result_lines or result_lines[-1] != "":
                result_lines.append("")
            continue
        if line not in seen_lines:
            seen_lines.add(line)
            result_lines.append(line)
    return "\n".join(result_lines).strip()
|
|