from typing import List, Optional, TypedDict
from types import SimpleNamespace
import requests
from langgraph.graph import StateGraph, END
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from .config import HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE

# Cache local model/pipeline to avoid repeated downloads.
_LOCAL_PIPELINE = None
_LOCAL_MODEL_ID = None


def _build_prompt(question: str, docs: List) -> str:
"""Create a concise prompt that uses retrieved context."""
context = "\n\n".join(d.page_content for d in docs[:4])
return (
"You are a helpful assistant. Use the provided context to answer the question. "
"If the context is insufficient, say you do not know.\n\n"
f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
)
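# For illustration, _build_prompt("What is covered in the docs?", docs) yields roughly:
#
#   You are a helpful assistant. Use the provided context to answer the question. If the
#   context is insufficient, say you do not know.
#
#   Context:
#   <page_content of the first retrieved chunk>
#
#   <page_content of the second retrieved chunk>
#
#   Question: What is covered in the docs?
#   Answer: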
class ChatState(TypedDict):
messages: List[BaseMessage]
    context: str


def _hf_generate(prompt: str, model_id: str, token: Optional[str], temperature: float) -> str:
"""
Minimal text generation call against the Hugging Face router API.
"""
url = f"https://router.huggingface.co/models/{model_id}"
headers = {"Accept": "application/json"}
if token:
headers["Authorization"] = f"Bearer {token}"
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 512,
"temperature": temperature,
"return_full_text": False,
},
}
try:
resp = requests.post(url, headers=headers, json=payload, timeout=60)
resp.raise_for_status()
except requests.HTTPError as http_err:
status = http_err.response.status_code if http_err.response is not None else None
if status == 404:
raise RuntimeError(
f"Model '{model_id}' not found on Hugging Face router. "
f"Set HF_MODEL_ID to a router-available text-generation model and retry."
) from http_err
raise
except requests.RequestException as req_err:
# Network layer issues (timeouts, DNS, etc.) should surface cleanly so we can fall back.
raise RuntimeError(f"Hugging Face router request failed: {req_err}") from req_err
data = resp.json()
# HF router can return list or dict; handle both
if isinstance(data, list) and data and isinstance(data[0], dict):
if "generated_text" in data[0]:
return data[0]["generated_text"]
if "error" in data[0]:
raise RuntimeError(data[0]["error"])
if isinstance(data, dict):
if "generated_text" in data:
return data["generated_text"]
if "error" in data:
raise RuntimeError(data["error"])
return str(data)
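# Illustrative response shapes the parsing above accepts (actual payloads depend on the
# model and provider):
#   [{"generated_text": "..."}]   -> the generated string is returned
#   {"error": "<message>"}        -> surfaced as a RuntimeError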
def _local_generate(prompt: str, model_id: str, temperature: float) -> str:
"""
Fallback local generation using transformers pipeline (no HF API token needed).
Truncates the prompt to fit within the model's max position embeddings to avoid index errors.
"""
global _LOCAL_PIPELINE, _LOCAL_MODEL_ID
if _LOCAL_PIPELINE is None or _LOCAL_MODEL_ID != model_id:
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
_LOCAL_PIPELINE = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
device_map="cpu",
)
_LOCAL_MODEL_ID = model_id
tokenizer = _LOCAL_PIPELINE.tokenizer
model = _LOCAL_PIPELINE.model
max_new_tokens = 128
# Determine max prompt length to prevent IndexError for small context windows (e.g., gpt2 = 1024).
max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None)
pad_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id
if max_positions and isinstance(max_positions, int):
allowed = max_positions - max_new_tokens - 1
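        # e.g. for gpt2 (max_position_embeddings=1024): 1024 - 128 - 1 = 895 prompt tokens.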
if allowed > 0:
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
if len(input_ids) > allowed:
# Keep the tail of the prompt (most recent question + context)
input_ids = input_ids[-allowed:]
prompt = tokenizer.decode(input_ids, skip_special_tokens=True)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
        "pad_token_id": pad_token_id,
        # Return only newly generated text so local output matches the router's
        # return_full_text=False behaviour.
        "return_full_text": False,
    }
    if temperature > 0:
        # Pass temperature only when sampling; greedy decoding ignores it and newer
        # transformers versions warn about unused sampling flags.
        gen_kwargs["temperature"] = temperature
    outputs = _LOCAL_PIPELINE(prompt, **gen_kwargs)
# transformers pipeline returns list of dicts
if outputs and isinstance(outputs[0], dict) and "generated_text" in outputs[0]:
return outputs[0]["generated_text"]
    return str(outputs)


def build_agent(
vectorstore,
hf_model_id: Optional[str] = None,
hf_api_token: Optional[str] = None,
temperature: Optional[float] = None,
):
"""
Simple RAG agent using Hugging Face router inference (text_generation).
"""
retriever = vectorstore.as_retriever()
model_id = (hf_model_id or HF_MODEL_ID).strip()
local_model_id = (LOCAL_MODEL_ID or model_id).strip()
token = (hf_api_token or HF_API_TOKEN or "").strip() or None
temp = TEMPERATURE if temperature is None else temperature
def invoke(payload):
messages = payload.get("messages", [])
user_content = messages[-1].content if messages else ""
# prefer invoke to avoid deprecation warnings
if hasattr(retriever, "invoke"):
docs = retriever.invoke(user_content)
else:
docs = retriever.get_relevant_documents(user_content)
prompt = _build_prompt(user_content, docs)
# Use router if a token is provided; otherwise fall back to local generation.
try:
if token:
text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
else:
text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
except Exception as api_err:
if token:
# Degrade gracefully to local generation when router is flaky or the model is blocked.
fallback_note = (
f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
)
print(fallback_note)
text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
text = f"{text}\n\n{fallback_note}"
else:
raise
return {"messages": [AIMessage(content=text)]}
# Return an object with an invoke method to mirror previous agent_executor shape
return SimpleNamespace(invoke=invoke)
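# Example usage (sketch; assumes a vectorstore built elsewhere, e.g. by the ingestion
# code that pairs with this module):
#
#   agent = build_agent(vectorstore)
#   result = agent.invoke({"messages": [HumanMessage(content="What does the document say about X?")]})
#   print(result["messages"][-1].content)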
def build_langgraph_agent(
vectorstore,
hf_model_id: Optional[str] = None,
hf_api_token: Optional[str] = None,
temperature: Optional[float] = None,
):
"""
LangGraph-based RAG agent with retrieval + generation nodes.
"""
retriever = vectorstore.as_retriever()
model_id = (hf_model_id or HF_MODEL_ID).strip()
local_model_id = (LOCAL_MODEL_ID or model_id).strip()
token = (hf_api_token or HF_API_TOKEN or "").strip() or None
temp = TEMPERATURE if temperature is None else temperature
def retrieve_node(state: ChatState):
messages = state.get("messages", [])
user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
query = user_msg.content if user_msg else ""
if hasattr(retriever, "invoke"):
docs = retriever.invoke(query)
else:
docs = retriever.get_relevant_documents(query)
context = "\n\n".join(d.page_content for d in docs[:4])
return {"context": context}
def generate_node(state: ChatState):
messages = state.get("messages", [])
context = state.get("context", "")
user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
question = user_msg.content if user_msg else ""
prompt = (
"You are a helpful assistant. Use the provided context to answer the question. "
"If the context is insufficient, say you do not know.\n\n"
f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
)
try:
if token:
text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
else:
text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
except Exception as api_err:
if token:
fallback_note = (
f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
)
print(fallback_note)
text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
text = f"{text}\n\n{fallback_note}"
else:
raise
return {"messages": messages + [AIMessage(content=text)]}
graph = StateGraph(ChatState)
graph.add_node("retrieve", retrieve_node)
graph.add_node("generate", generate_node)
graph.set_entry_point("retrieve")
graph.add_edge("retrieve", "generate")
graph.add_edge("generate", END)
app = graph.compile()
# Wrap to mirror the previous agent_executor interface for Gradio.
def invoke(payload):
incoming_messages = payload.get("messages", [])
initial_state: ChatState = {"messages": incoming_messages, "context": ""}
return app.invoke(initial_state)
return SimpleNamespace(invoke=invoke)
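# Example usage (sketch; mirrors build_agent and assumes an existing vectorstore):
#
#   agent = build_langgraph_agent(vectorstore)
#   state = agent.invoke({"messages": [HumanMessage(content="Summarise the uploaded PDF.")]})
#   print(state["messages"][-1].content)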