File size: 2,259 Bytes
a53be7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import List

from dotenv import load_dotenv

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage


# ================================
# Settings / Env
# ================================
@dataclass(frozen=True)
class Settings:
    google_api_key: str
    file_search_store_name: str
    model_name: str = "gemini-2.5-flash"


def load_settings() -> Settings:
    load_dotenv()

    key = (os.getenv("GOOGLE_API_KEY") or "").strip()
    if not (key and len(key) > 10):
        raise RuntimeError(
            "GOOGLE_API_KEY not found (or looks wrong). Put it in .env as GOOGLE_API_KEY=..."
        )

    store = (os.getenv("FILE_SEARCH_STORE_NAME") or "").strip()
    if not store:
        raise RuntimeError(
            "FILE_SEARCH_STORE_NAME not found. Put it in .env as FILE_SEARCH_STORE_NAME=fileSearchStores/..."
        )

    model = (os.getenv("GEMINI_MODEL") or "gemini-2.5-flash").strip()
    return Settings(google_api_key=key, file_search_store_name=store, model_name=model)


# ================================
# Memory helpers (keep last ~N words)
# ================================
def _count_words(msg) -> int:
    text = getattr(msg, "content", "") or ""
    if isinstance(text, list):
        text = " ".join(
            [b.get("text", "") if isinstance(b, dict) else str(b) for b in text]
        )
    return len(str(text).split())


def trim_history_to_words(
    chat_history: ChatMessageHistory,
    max_words: int = 5000,
) -> List:
    msgs = chat_history.messages
    total = 0
    kept = []
    for m in reversed(msgs):
        w = _count_words(m)
        if total + w > max_words:
            break
        kept.append(m)
        total += w
    return list(reversed(kept))


def build_transcript(past_messages: List) -> str:
    lines = []
    for m in past_messages:
        role = "User" if isinstance(m, HumanMessage) else "Assistant"
        lines.append(f"{role}: {getattr(m, 'content', '')}")
    return "\n".join(lines).strip()


def load_prompt(path: str) -> str:
    with open(path, "r", encoding="utf-8") as f:
        return f.read().strip()