# RAG_LangGraph/src/agent.py
from typing import List, Optional, TypedDict
from types import SimpleNamespace
import requests
from langgraph.graph import StateGraph, END
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from .config import HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE

# Cache the local model/pipeline to avoid repeated downloads.
_LOCAL_PIPELINE = None
_LOCAL_MODEL_ID = None

def _build_prompt(question: str, docs: List) -> str:
    """Create a concise prompt that uses retrieved context."""
    context = "\n\n".join(d.page_content for d in docs[:4])
    return (
        "You are a helpful assistant. Use the provided context to answer the question. "
        "If the context is insufficient, say you do not know.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
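
# Minimal usage sketch (hypothetical inputs): any object that exposes a
# `page_content` attribute works as a "doc" here, e.g.
#   docs = [SimpleNamespace(page_content="Paris is the capital of France.")]
#   _build_prompt("What is the capital of France?", docs)
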
class ChatState(TypedDict):
    messages: List[BaseMessage]
    context: str

def _hf_generate(prompt: str, model_id: str, token: Optional[str], temperature: float) -> str:
    """
    Minimal text-generation call against the Hugging Face router API.
    """
    # The router expects a provider segment in the path; "hf-inference" routes
    # to Hugging Face's own inference provider.
    url = f"https://router.huggingface.co/hf-inference/models/{model_id}"
    headers = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 512,
            "temperature": temperature,
            # Return only the completion, not the prompt echoed back.
            "return_full_text": False,
        },
    }
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=60)
        resp.raise_for_status()
    except requests.HTTPError as http_err:
        status = http_err.response.status_code if http_err.response is not None else None
        if status == 404:
            raise RuntimeError(
                f"Model '{model_id}' not found on Hugging Face router. "
                f"Set HF_MODEL_ID to a router-available text-generation model and retry."
            ) from http_err
        raise
    except requests.RequestException as req_err:
        # Network-layer issues (timeouts, DNS, etc.) should surface cleanly so we can fall back.
        raise RuntimeError(f"Hugging Face router request failed: {req_err}") from req_err
    data = resp.json()
    # The router can return either a list or a dict; handle both.
    if isinstance(data, list) and data and isinstance(data[0], dict):
        if "generated_text" in data[0]:
            return data[0]["generated_text"]
        if "error" in data[0]:
            raise RuntimeError(data[0]["error"])
    if isinstance(data, dict):
        if "generated_text" in data:
            return data["generated_text"]
        if "error" in data:
            raise RuntimeError(data["error"])
    return str(data)
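
# Typical response shapes handled above (exact payloads vary by model and
# provider, so the parsing stays defensive):
#   [{"generated_text": "..."}]        -> list form
#   {"generated_text": "..."}          -> dict form
#   {"error": "model is loading"}      -> error form, surfaced as RuntimeError
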
def _local_generate(prompt: str, model_id: str, temperature: float) -> str:
    """
    Fallback local generation using a transformers pipeline (no HF API token needed).
    Truncates the prompt to fit within the model's max position embeddings to avoid index errors.
    """
    global _LOCAL_PIPELINE, _LOCAL_MODEL_ID
    if _LOCAL_PIPELINE is None or _LOCAL_MODEL_ID != model_id:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id)
        _LOCAL_PIPELINE = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="cpu",
        )
        _LOCAL_MODEL_ID = model_id
    tokenizer = _LOCAL_PIPELINE.tokenizer
    model = _LOCAL_PIPELINE.model
    max_new_tokens = 128
    # Determine the max prompt length to prevent IndexError on models with small
    # context windows (e.g., gpt2 = 1024 positions).
    max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    # Prefer EOS as the pad token; models like gpt2 define no dedicated pad token.
    pad_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id
    if isinstance(max_positions, int) and max_positions > 0:
        allowed = max_positions - max_new_tokens - 1
        if allowed > 0:
            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
            if len(input_ids) > allowed:
                # Keep the tail of the prompt (most recent question + context).
                input_ids = input_ids[-allowed:]
                prompt = tokenizer.decode(input_ids, skip_special_tokens=True)
    outputs = _LOCAL_PIPELINE(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        temperature=temperature,
        pad_token_id=pad_token_id,
        # Match the router behavior: return only the completion, not the prompt.
        return_full_text=False,
    )
    # The transformers pipeline returns a list of dicts.
    if outputs and isinstance(outputs[0], dict) and "generated_text" in outputs[0]:
        return outputs[0]["generated_text"]
    return str(outputs)
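
# Usage sketch (assumes a small causal LM such as "gpt2" is cached locally or
# downloadable; the first call is slow because the model has to load):
#   text = _local_generate("Question: What is 2 + 2?\nAnswer:", model_id="gpt2", temperature=0.0)
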
def build_agent(
    vectorstore,
    hf_model_id: Optional[str] = None,
    hf_api_token: Optional[str] = None,
    temperature: Optional[float] = None,
):
    """
    Simple RAG agent: Hugging Face router inference when a token is available,
    with graceful fallback to local transformers generation.
    """
    retriever = vectorstore.as_retriever()
    model_id = (hf_model_id or HF_MODEL_ID).strip()
    local_model_id = (LOCAL_MODEL_ID or model_id).strip()
    token = (hf_api_token or HF_API_TOKEN or "").strip() or None
    temp = TEMPERATURE if temperature is None else temperature

    def invoke(payload):
        messages = payload.get("messages", [])
        user_content = messages[-1].content if messages else ""
        # Prefer invoke() to avoid deprecation warnings on newer LangChain versions.
        if hasattr(retriever, "invoke"):
            docs = retriever.invoke(user_content)
        else:
            docs = retriever.get_relevant_documents(user_content)
        prompt = _build_prompt(user_content, docs)
        # Use the router if a token is provided; otherwise fall back to local generation.
        try:
            if token:
                text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
            else:
                text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
        except Exception as api_err:
            if token:
                # Degrade gracefully to local generation when the router is flaky
                # or the model is blocked.
                fallback_note = (
                    f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
                )
                print(fallback_note)
                text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
                text = f"{text}\n\n{fallback_note}"
            else:
                raise
        return {"messages": [AIMessage(content=text)]}

    # Return an object with an invoke method to mirror the previous agent_executor shape.
    return SimpleNamespace(invoke=invoke)
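
# Usage sketch (hypothetical `vectorstore`; any LangChain vector store that
# implements `as_retriever()` works):
#   agent = build_agent(vectorstore)
#   result = agent.invoke({"messages": [HumanMessage(content="What does the doc say?")]})
#   print(result["messages"][-1].content)
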
def build_langgraph_agent(
    vectorstore,
    hf_model_id: Optional[str] = None,
    hf_api_token: Optional[str] = None,
    temperature: Optional[float] = None,
):
    """
    LangGraph-based RAG agent with separate retrieval and generation nodes.
    """
    retriever = vectorstore.as_retriever()
    model_id = (hf_model_id or HF_MODEL_ID).strip()
    local_model_id = (LOCAL_MODEL_ID or model_id).strip()
    token = (hf_api_token or HF_API_TOKEN or "").strip() or None
    temp = TEMPERATURE if temperature is None else temperature

    def retrieve_node(state: ChatState):
        messages = state.get("messages", [])
        user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
        query = user_msg.content if user_msg else ""
        if hasattr(retriever, "invoke"):
            docs = retriever.invoke(query)
        else:
            docs = retriever.get_relevant_documents(query)
        context = "\n\n".join(d.page_content for d in docs[:4])
        return {"context": context}

    def generate_node(state: ChatState):
        messages = state.get("messages", [])
        context = state.get("context", "")
        user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
        question = user_msg.content if user_msg else ""
        # Mirrors _build_prompt, but works from the pre-joined context string.
        prompt = (
            "You are a helpful assistant. Use the provided context to answer the question. "
            "If the context is insufficient, say you do not know.\n\n"
            f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
        )
        try:
            if token:
                text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
            else:
                text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
        except Exception as api_err:
            if token:
                fallback_note = (
                    f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
                )
                print(fallback_note)
                text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
                text = f"{text}\n\n{fallback_note}"
            else:
                raise
        return {"messages": messages + [AIMessage(content=text)]}

    # Wire the two-node graph: retrieve -> generate -> END.
    graph = StateGraph(ChatState)
    graph.add_node("retrieve", retrieve_node)
    graph.add_node("generate", generate_node)
    graph.set_entry_point("retrieve")
    graph.add_edge("retrieve", "generate")
    graph.add_edge("generate", END)
    app = graph.compile()

    # Wrap to mirror the previous agent_executor interface for Gradio.
    def invoke(payload):
        incoming_messages = payload.get("messages", [])
        initial_state: ChatState = {"messages": incoming_messages, "context": ""}
        return app.invoke(initial_state)

    return SimpleNamespace(invoke=invoke)
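
# Usage sketch (mirrors build_agent; hypothetical `vectorstore`):
#   agent = build_langgraph_agent(vectorstore)
#   state = agent.invoke({"messages": [HumanMessage(content="Summarize the document.")]})
#   print(state["messages"][-1].content)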