from typing import List, Optional, TypedDict from types import SimpleNamespace import requests from langgraph.graph import StateGraph, END from langchain_core.messages import AIMessage, BaseMessage, HumanMessage from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM from .config import HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE # Cache local model/pipeline to avoid repeated downloads. _LOCAL_PIPELINE = None _LOCAL_MODEL_ID = None def _build_prompt(question: str, docs: List) -> str: """Create a concise prompt that uses retrieved context.""" context = "\n\n".join(d.page_content for d in docs[:4]) return ( "You are a helpful assistant. Use the provided context to answer the question. " "If the context is insufficient, say you do not know.\n\n" f"Context:\n{context}\n\nQuestion: {question}\nAnswer:" ) class ChatState(TypedDict): messages: List[BaseMessage] context: str def _hf_generate(prompt: str, model_id: str, token: Optional[str], temperature: float) -> str: """ Minimal text generation call against the Hugging Face router API. """ url = f"https://router.huggingface.co/models/{model_id}" headers = {"Accept": "application/json"} if token: headers["Authorization"] = f"Bearer {token}" payload = { "inputs": prompt, "parameters": { "max_new_tokens": 512, "temperature": temperature, "return_full_text": False, }, } try: resp = requests.post(url, headers=headers, json=payload, timeout=60) resp.raise_for_status() except requests.HTTPError as http_err: status = http_err.response.status_code if http_err.response is not None else None if status == 404: raise RuntimeError( f"Model '{model_id}' not found on Hugging Face router. " f"Set HF_MODEL_ID to a router-available text-generation model and retry." ) from http_err raise except requests.RequestException as req_err: # Network layer issues (timeouts, DNS, etc.) should surface cleanly so we can fall back. raise RuntimeError(f"Hugging Face router request failed: {req_err}") from req_err data = resp.json() # HF router can return list or dict; handle both if isinstance(data, list) and data and isinstance(data[0], dict): if "generated_text" in data[0]: return data[0]["generated_text"] if "error" in data[0]: raise RuntimeError(data[0]["error"]) if isinstance(data, dict): if "generated_text" in data: return data["generated_text"] if "error" in data: raise RuntimeError(data["error"]) return str(data) def _local_generate(prompt: str, model_id: str, temperature: float) -> str: """ Fallback local generation using transformers pipeline (no HF API token needed). Truncates the prompt to fit within the model's max position embeddings to avoid index errors. """ global _LOCAL_PIPELINE, _LOCAL_MODEL_ID if _LOCAL_PIPELINE is None or _LOCAL_MODEL_ID != model_id: tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id) _LOCAL_PIPELINE = pipeline( "text-generation", model=model, tokenizer=tokenizer, device_map="cpu", ) _LOCAL_MODEL_ID = model_id tokenizer = _LOCAL_PIPELINE.tokenizer model = _LOCAL_PIPELINE.model max_new_tokens = 128 # Determine max prompt length to prevent IndexError for small context windows (e.g., gpt2 = 1024). max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None) pad_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id if max_positions and isinstance(max_positions, int): allowed = max_positions - max_new_tokens - 1 if allowed > 0: input_ids = tokenizer.encode(prompt, add_special_tokens=False) if len(input_ids) > allowed: # Keep the tail of the prompt (most recent question + context) input_ids = input_ids[-allowed:] prompt = tokenizer.decode(input_ids, skip_special_tokens=True) outputs = _LOCAL_PIPELINE( prompt, max_new_tokens=max_new_tokens, do_sample=temperature > 0, temperature=temperature, pad_token_id=pad_token_id, ) # transformers pipeline returns list of dicts if outputs and isinstance(outputs[0], dict) and "generated_text" in outputs[0]: return outputs[0]["generated_text"] return str(outputs) def build_agent( vectorstore, hf_model_id: Optional[str] = None, hf_api_token: Optional[str] = None, temperature: Optional[float] = None, ): """ Simple RAG agent using Hugging Face router inference (text_generation). """ retriever = vectorstore.as_retriever() model_id = (hf_model_id or HF_MODEL_ID).strip() local_model_id = (LOCAL_MODEL_ID or model_id).strip() token = (hf_api_token or HF_API_TOKEN or "").strip() or None temp = TEMPERATURE if temperature is None else temperature def invoke(payload): messages = payload.get("messages", []) user_content = messages[-1].content if messages else "" # prefer invoke to avoid deprecation warnings if hasattr(retriever, "invoke"): docs = retriever.invoke(user_content) else: docs = retriever.get_relevant_documents(user_content) prompt = _build_prompt(user_content, docs) # Use router if a token is provided; otherwise fall back to local generation. try: if token: text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp) else: text = _local_generate(prompt, model_id=local_model_id, temperature=temp) except Exception as api_err: if token: # Degrade gracefully to local generation when router is flaky or the model is blocked. fallback_note = ( f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]" ) print(fallback_note) text = _local_generate(prompt, model_id=local_model_id, temperature=temp) text = f"{text}\n\n{fallback_note}" else: raise return {"messages": [AIMessage(content=text)]} # Return an object with an invoke method to mirror previous agent_executor shape return SimpleNamespace(invoke=invoke) def build_langgraph_agent( vectorstore, hf_model_id: Optional[str] = None, hf_api_token: Optional[str] = None, temperature: Optional[float] = None, ): """ LangGraph-based RAG agent with retrieval + generation nodes. """ retriever = vectorstore.as_retriever() model_id = (hf_model_id or HF_MODEL_ID).strip() local_model_id = (LOCAL_MODEL_ID or model_id).strip() token = (hf_api_token or HF_API_TOKEN or "").strip() or None temp = TEMPERATURE if temperature is None else temperature def retrieve_node(state: ChatState): messages = state.get("messages", []) user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None) query = user_msg.content if user_msg else "" if hasattr(retriever, "invoke"): docs = retriever.invoke(query) else: docs = retriever.get_relevant_documents(query) context = "\n\n".join(d.page_content for d in docs[:4]) return {"context": context} def generate_node(state: ChatState): messages = state.get("messages", []) context = state.get("context", "") user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None) question = user_msg.content if user_msg else "" prompt = ( "You are a helpful assistant. Use the provided context to answer the question. " "If the context is insufficient, say you do not know.\n\n" f"Context:\n{context}\n\nQuestion: {question}\nAnswer:" ) try: if token: text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp) else: text = _local_generate(prompt, model_id=local_model_id, temperature=temp) except Exception as api_err: if token: fallback_note = ( f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]" ) print(fallback_note) text = _local_generate(prompt, model_id=local_model_id, temperature=temp) text = f"{text}\n\n{fallback_note}" else: raise return {"messages": messages + [AIMessage(content=text)]} graph = StateGraph(ChatState) graph.add_node("retrieve", retrieve_node) graph.add_node("generate", generate_node) graph.set_entry_point("retrieve") graph.add_edge("retrieve", "generate") graph.add_edge("generate", END) app = graph.compile() # Wrap to mirror the previous agent_executor interface for Gradio. def invoke(payload): incoming_messages = payload.get("messages", []) initial_state: ChatState = {"messages": incoming_messages, "context": ""} return app.invoke(initial_state) return SimpleNamespace(invoke=invoke)