File size: 9,589 Bytes
99f19b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
from typing import List, Optional, TypedDict
from types import SimpleNamespace
import requests
from langgraph.graph import StateGraph, END
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from .config import HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE

# Module-level cache for the local transformers pipeline. `_local_generate`
# rebuilds it only when no pipeline exists yet or a different model id is
# requested, so the model is not reloaded on every call.
_LOCAL_PIPELINE = None
# Model id the cached pipeline was built for; compared against the requested id.
_LOCAL_MODEL_ID = None


def _build_prompt(question: str, docs: List, max_docs: int = 4) -> str:
    """Build the RAG prompt from a question and the top retrieved documents.

    Args:
        question: The user's question, inserted verbatim after the context.
        docs: Retrieved documents in rank order. Each item must expose a
            ``page_content`` string attribute. ``None`` or an empty list
            yields an empty context section.
        max_docs: Number of leading documents to include in the context.
            Defaults to 4; negative values are treated as 0.

    Returns:
        The complete prompt string, ending in ``Answer:``.
    """
    # Clamp the limit: a negative slice bound would mean "all but the last N",
    # which is never what a document cap is meant to express.
    selected = (docs or [])[: max(max_docs, 0)]
    context = "\n\n".join(d.page_content for d in selected)
    return (
        "You are a helpful assistant. Use the provided context to answer the question. "
        "If the context is insufficient, say you do not know.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )


class ChatState(TypedDict):
    """State dictionary passed between the nodes of the LangGraph agent."""

    # Conversation so far; ``generate_node`` returns it with the new AIMessage appended.
    messages: List[BaseMessage]
    # Retrieved document text, joined with blank lines by ``retrieve_node``.
    context: str


def _hf_generate(prompt: str, model_id: str, token: Optional[str], temperature: float) -> str:
    """Generate a completion for ``prompt`` through the Hugging Face router API.

    Args:
        prompt: Full prompt text sent as the ``inputs`` field.
        model_id: Model identifier appended to the router URL.
        token: Optional bearer token; when falsy no Authorization header is sent.
        temperature: Sampling temperature. A non-positive (or missing) value
            requests greedy decoding: the temperature is left out of the
            payload and ``do_sample`` is set to False, mirroring how
            ``_local_generate`` disables sampling in the same situation.

    Returns:
        The ``generated_text`` from the response. If the decoded payload has
        neither ``generated_text`` nor ``error``, its ``str()`` form is
        returned instead.

    Raises:
        RuntimeError: The model was not found (HTTP 404), the request failed
            at the transport level, the body was not valid JSON, or the
            payload carried an ``error`` field.
        requests.HTTPError: Any other non-success HTTP status.
    """
    url = f"https://router.huggingface.co/models/{model_id}"
    headers = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    parameters = {
        "max_new_tokens": 512,
        "return_full_text": False,
    }
    if temperature is not None and temperature > 0:
        parameters["temperature"] = temperature
    else:
        # Text-generation backends generally require a strictly positive
        # temperature, so ask for deterministic decoding instead.
        parameters["do_sample"] = False
    payload = {"inputs": prompt, "parameters": parameters}

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
    except requests.HTTPError as http_err:
        status = http_err.response.status_code if http_err.response is not None else None
        if status == 404:
            raise RuntimeError(
                f"Model '{model_id}' not found on Hugging Face router. "
                f"Set HF_MODEL_ID to a router-available text-generation model and retry."
            ) from http_err
        raise
    except requests.RequestException as req_err:
        # Network layer issues (timeouts, DNS, etc.) should surface cleanly so we can fall back.
        raise RuntimeError(f"Hugging Face router request failed: {req_err}") from req_err

    try:
        data = resp.json()
    except ValueError as json_err:
        # JSON decode errors are ValueError subclasses; report them like the
        # other router failures rather than leaking a bare decode error.
        raise RuntimeError(
            f"Hugging Face router returned a non-JSON response: {json_err}"
        ) from json_err

    # The router can answer with a list of records or a single record.
    if isinstance(data, list) and data and isinstance(data[0], dict):
        record = data[0]
    elif isinstance(data, dict):
        record = data
    else:
        record = {}

    if "generated_text" in record:
        return record["generated_text"]
    if "error" in record:
        raise RuntimeError(record["error"])
    # Unknown shape: hand back the raw payload as text, as before.
    return str(data)


def _local_generate(prompt: str, model_id: str, temperature: float) -> str:
    """Generate a completion locally with a cached transformers pipeline.

    Used when no Hugging Face API token is available or the router call
    failed. The pipeline is built on first use and rebuilt only when a
    different ``model_id`` is requested.

    Args:
        prompt: Full prompt text. If it is longer than the model's
            ``max_position_embeddings`` allows (after reserving room for the
            new tokens), only the tail of the prompt is kept.
        model_id: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the causal LM.
        temperature: Sampling temperature. Sampling is enabled only when this
            is greater than zero; otherwise decoding is greedy and no
            temperature is forwarded.

    Returns:
        The newly generated text only (the prompt is not echoed back), or the
        ``str()`` of the pipeline output if it has an unexpected shape.
    """
    global _LOCAL_PIPELINE, _LOCAL_MODEL_ID

    if _LOCAL_PIPELINE is None or _LOCAL_MODEL_ID != model_id:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        # NOTE(review): device_map is passed alongside an already-loaded
        # model; confirm it has the intended effect on placement.
        _LOCAL_PIPELINE = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="cpu",
        )
        _LOCAL_MODEL_ID = model_id

    tokenizer = _LOCAL_PIPELINE.tokenizer
    model = _LOCAL_PIPELINE.model
    max_new_tokens = 128

    # Determine max prompt length to prevent IndexError for small context windows (e.g., gpt2 = 1024).
    max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if max_positions and isinstance(max_positions, int):
        # Reserve space for the generated tokens plus one spare position.
        allowed = max_positions - max_new_tokens - 1
        if allowed > 0:
            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
            if len(input_ids) > allowed:
                # Keep the tail of the prompt (most recent question + context)
                input_ids = input_ids[-allowed:]
                prompt = tokenizer.decode(input_ids, skip_special_tokens=True)

    # Token id 0 is a legitimate id, so test against None rather than truthiness.
    pad_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    sampling = temperature is not None and temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        "pad_token_id": pad_token_id,
        # Match the router path, which also asks for the completion only;
        # without this the pipeline returns prompt + completion.
        "return_full_text": False,
    }
    if sampling:
        # Temperature only matters when sampling; leave it out for greedy decoding.
        generation_kwargs["temperature"] = temperature

    outputs = _LOCAL_PIPELINE(prompt, **generation_kwargs)
    # transformers pipeline returns list of dicts
    if outputs and isinstance(outputs[0], dict) and "generated_text" in outputs[0]:
        return outputs[0]["generated_text"]
    return str(outputs)


def build_agent(
    vectorstore,
    hf_model_id: Optional[str] = None,
    hf_api_token: Optional[str] = None,
    temperature: Optional[float] = None,
):
    """
    Build a minimal retrieve-then-generate agent.

    Explicit arguments take precedence over the module configuration
    (HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE). With a token the
    Hugging Face router is tried first and the local model is the fallback;
    without a token only the local model is used.

    Returns an object exposing ``invoke(payload)``, where ``payload`` is a
    mapping with a ``"messages"`` list, and the result is
    ``{"messages": [AIMessage]}``.
    """
    doc_retriever = vectorstore.as_retriever()
    remote_model = (hf_model_id or HF_MODEL_ID).strip()
    fallback_model = (LOCAL_MODEL_ID or remote_model).strip()
    api_token = (hf_api_token or HF_API_TOKEN or "").strip() or None
    sampling_temp = temperature if temperature is not None else TEMPERATURE

    def run_local(prompt_text):
        # Single place for the local-generation call used by both paths below.
        return _local_generate(prompt_text, model_id=fallback_model, temperature=sampling_temp)

    def invoke(payload):
        history = payload.get("messages", [])
        query = history[-1].content if history else ""

        # Prefer the newer ``invoke`` entry point to avoid deprecation warnings.
        search = (
            doc_retriever.invoke
            if hasattr(doc_retriever, "invoke")
            else doc_retriever.get_relevant_documents
        )
        prompt_text = _build_prompt(query, search(query))

        if not api_token:
            # No token: local generation only; any error propagates to the caller.
            answer = run_local(prompt_text)
        else:
            try:
                answer = _hf_generate(
                    prompt_text, model_id=remote_model, token=api_token, temperature=sampling_temp
                )
            except Exception as router_err:
                # Router is flaky or the model is blocked: degrade to the local model
                # and tell the user why the answer came from it.
                note = (
                    f"[Fallback to local model '{fallback_model}' because HF router failed: {router_err}]"
                )
                print(note)
                answer = run_local(prompt_text)
                answer = f"{answer}\n\n{note}"

        return {"messages": [AIMessage(content=answer)]}

    # Expose ``invoke`` on a simple object to mirror the previous agent_executor shape.
    return SimpleNamespace(invoke=invoke)


def build_langgraph_agent(
    vectorstore,
    hf_model_id: Optional[str] = None,
    hf_api_token: Optional[str] = None,
    temperature: Optional[float] = None,
):
    """
    Build the RAG agent as a two-node LangGraph: ``retrieve`` then ``generate``.

    Explicit arguments take precedence over the module configuration
    (HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE). The query is the
    content of the most recent HumanMessage in the incoming messages, or an
    empty string when there is none.

    Returns an object exposing ``invoke(payload)``; ``payload`` is a mapping
    with a ``"messages"`` list and the result is whatever the compiled graph
    returns for that initial state.
    """
    doc_retriever = vectorstore.as_retriever()
    remote_model = (hf_model_id or HF_MODEL_ID).strip()
    fallback_model = (LOCAL_MODEL_ID or remote_model).strip()
    api_token = (hf_api_token or HF_API_TOKEN or "").strip() or None
    sampling_temp = temperature if temperature is not None else TEMPERATURE

    def latest_human_text(history):
        # Walk backwards so the newest human turn wins; "" when there is none.
        for message in reversed(history):
            if isinstance(message, HumanMessage):
                return message.content
        return ""

    def run_local(prompt_text):
        return _local_generate(prompt_text, model_id=fallback_model, temperature=sampling_temp)

    def retrieve_node(state: ChatState):
        query = latest_human_text(state.get("messages", []))

        # Prefer the newer ``invoke`` entry point when the retriever has one.
        search = (
            doc_retriever.invoke
            if hasattr(doc_retriever, "invoke")
            else doc_retriever.get_relevant_documents
        )
        top_docs = search(query)[:4]
        return {"context": "\n\n".join(doc.page_content for doc in top_docs)}

    def generate_node(state: ChatState):
        history = state.get("messages", [])
        retrieved = state.get("context", "")
        question = latest_human_text(history)

        prompt_text = (
            "You are a helpful assistant. Use the provided context to answer the question. "
            "If the context is insufficient, say you do not know.\n\n"
            f"Context:\n{retrieved}\n\nQuestion: {question}\nAnswer:"
        )

        if not api_token:
            # No token: local generation only; any error propagates to the caller.
            answer = run_local(prompt_text)
        else:
            try:
                answer = _hf_generate(
                    prompt_text, model_id=remote_model, token=api_token, temperature=sampling_temp
                )
            except Exception as router_err:
                # Router failed: fall back to the local model and annotate the answer.
                note = (
                    f"[Fallback to local model '{fallback_model}' because HF router failed: {router_err}]"
                )
                print(note)
                answer = run_local(prompt_text)
                answer = f"{answer}\n\n{note}"

        return {"messages": history + [AIMessage(content=answer)]}

    workflow = StateGraph(ChatState)
    workflow.add_node("retrieve", retrieve_node)
    workflow.add_node("generate", generate_node)
    workflow.set_entry_point("retrieve")
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", END)
    compiled = workflow.compile()

    # Thin wrapper so callers (e.g. the Gradio UI) keep the agent_executor-style interface.
    def invoke(payload):
        initial_state: ChatState = {
            "messages": payload.get("messages", []),
            "context": "",
        }
        return compiled.invoke(initial_state)

    return SimpleNamespace(invoke=invoke)