Anandharajan committed
Commit 99f19b3 · 1 Parent(s): 12e4ce3

Sync Space with LangGraph RAG app
.gitattributes DELETED
@@ -1,35 +0,0 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/workflows/deploy-space.yml ADDED
@@ -0,0 +1,19 @@
+ name: Deploy to Hugging Face Space
+
+ on:
+   push:
+     branches: [ master ]
+     tags: [ 'v*' ]
+   workflow_dispatch:
+
+ jobs:
+   sync-space:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+       - name: Push to Hugging Face Space
+         uses: huggingface/hub-action@v1
+         with:
+           repo-token: ${{ secrets.HF_TOKEN }}
+           repo-id: ${{ secrets.HF_SPACE_ID }}
+           repo-type: space
.gitignore ADDED
@@ -0,0 +1,11 @@
+ __pycache__/
+ *.py[cod]
+ .DS_Store
+ .venv/
+
+ # Local artifacts
+ data/source.pdf
+ data/faiss_index/
+
+ # Gradio upload temp files (just in case)
+ tmp/
README.md CHANGED
@@ -1,16 +1,69 @@
- ---
- title: RAG LangGraph
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 5.42.0
- app_file: app.py
- pinned: false
- hf_oauth: true
- hf_oauth_scopes:
-   - inference-api
- short_description: ' A LangGraph-powered RAG chatbot '
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ # RAG-Based Chatbot (LangGraph + Hugging Face)
+
+ This project implements a RAG (Retrieval-Augmented Generation) chatbot that answers with either:
+ - the **Hugging Face router** (when you provide an HF token and a router-available model; default `HF_MODEL_ID`: `meta-llama/Meta-Llama-3-8B-Instruct`), or
+ - **local transformers generation** (no token; fallback `LOCAL_MODEL_ID`: `distilgpt2` by default; quality is limited, so set a stronger local model if you need better offline answers).
+
+ ## Features
+ - **RAG Pipeline**: Ingests, chunks, embeds, and indexes PDF documents for accurate retrieval.
+ - **Inference Flexibility**: Uses the HF router when a token is provided; falls back to local transformers otherwise.
+ - **LangGraph Agent**: The retrieval + generation flow is orchestrated with LangGraph for clearer state handling.
+ - **Gradio Interface**: A user-friendly chat UI for interacting with the assistant.
+ - **Modular Design**: Clean separation of concerns (ingestion, vector store, agent, app).
+
+ ## Project Structure
+ ```
+ rag_agent_project/
+ ├─ app.py              # Gradio application
+ ├─ requirements.txt    # Dependencies
+ ├─ data/               # Data storage (PDFs, index)
+ ├─ src/                # Source code
+ │  ├─ ingestion.py     # Data processing
+ │  ├─ vectorstore.py   # Embedding & indexing
+ │  ├─ rag_tool.py      # (legacy) retriever tool helper
+ │  ├─ agent.py         # RAG + HF router/local agent
+ │  └─ config.py        # Configuration
+ └─ tests/              # Automated tests
+ ```
+
+ ## Setup & Usage
+
+ 1. **Install Dependencies**:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ 2. **Configure (optional)**:
+    - Set `HUGGINGFACEHUB_API_TOKEN` for router inference.
+    - Override `HF_MODEL_ID` for the router (default: `meta-llama/Meta-Llama-3-8B-Instruct`).
+    - Override `LOCAL_MODEL_ID` for the local fallback (default: `distilgpt2`; use a stronger local model if you need better offline answers).
+
+ 3. **Run the Application**:
+    ```bash
+    python app.py
+    ```
+
+ 4. **Interact**:
+    - Open the provided local URL (usually `http://127.0.0.1:7860`).
+    - (Optional) Provide a Hugging Face token and a router-supported model ID for cloud inference (default: `meta-llama/Meta-Llama-3-8B-Instruct`).
+    - Without a token, the app uses the local fallback model (`LOCAL_MODEL_ID`, default: `distilgpt2`); quality is limited, so use the router with a token for good answers or set a stronger local model.
+    - Upload a PDF and click "Initialize System".
+    - Start chatting!
+
+ ## Deployment (Hugging Face Spaces)
+ 1. Create a new Space on Hugging Face (SDK: Gradio).
+ 2. Upload the contents of `rag_agent_project` to the Space.
+ 3. Ensure `requirements.txt` is present.
+ 4. The app will build and launch automatically.
+
+ ## Technical Details
+ - **LLM**: HF router (with token; default `meta-llama/Meta-Llama-3-8B-Instruct`) or local transformers fallback (`LOCAL_MODEL_ID`, default `distilgpt2`; change to a stronger model if running locally).
+ - **Embeddings**: `sentence-transformers/all-MiniLM-L6-v2`
+ - **Vector Store**: FAISS
+ - **Orchestration**: LangGraph (retrieve → generate) with a RAG prompt built from the retrieved context
+
+ ## Notes for Hugging Face Spaces
+ - Add your `HUGGINGFACEHUB_API_TOKEN` as a secret for router usage.
+ - To pin a different router model, set `HF_MODEL_ID` in the Space variables. Override `LOCAL_MODEL_ID` if you want a specific offline fallback.
+ - The `data/` folder is used for uploads and the FAISS index; it is git-ignored here but created at runtime.
+ - The entry point is `app.py`; `demo.queue().launch()` is enabled for Spaces concurrency.
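The optional configuration described above can also be done from the shell before launching; a minimal sketch (the token value is a placeholder, not a real credential):

```shell
# Optional: enable Hugging Face router inference (placeholder token shown).
export HUGGINGFACEHUB_API_TOKEN="hf_xxxxxxxx"
# Optional: override the router model and the offline fallback.
export HF_MODEL_ID="meta-llama/Meta-Llama-3-8B-Instruct"
export LOCAL_MODEL_ID="distilgpt2"
python app.py
```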
app.py CHANGED
@@ -1,70 +1,180 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
      """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
      """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()


  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ import os
+ import shutil
+ from src.config import PDF_PATH, HF_API_TOKEN, HF_MODEL_ID, DATA_DIR
+ from src.ingestion import ingest_file
+ from src.vectorstore import create_vectorstore, load_vectorstore
+ from src.agent import build_langgraph_agent
+ from langchain_core.messages import HumanMessage
+
+ # Global variables to store state
+ vectorstore = None
+ agent_executor = None
+ current_hf_token = None
+ current_hf_model = None
+
+ # Ensure the data directory exists for uploads and the FAISS index (important for HF Spaces).
+ os.makedirs(DATA_DIR, exist_ok=True)
+
+
+ def _get_uploaded_path(uploaded_file):
+     """
+     Normalize Gradio's uploaded file into a filesystem path.
+     Handles filepath strings, temporary file objects, and dict payloads.
+     """
+     if uploaded_file is None:
+         return None
+
+     if isinstance(uploaded_file, (str, os.PathLike)):
+         return str(uploaded_file)
+
+     if isinstance(uploaded_file, dict):
+         return uploaded_file.get("name") or uploaded_file.get("path")
+
+     if hasattr(uploaded_file, "name"):
+         return uploaded_file.name
+
+     return None
+
+
+ def initialize_system(hf_token, hf_model, uploaded_file):
      """
+     Initializes the RAG pipeline and agent.
      """
+     global vectorstore, agent_executor, current_hf_token, current_hf_model
+
+     hf_token = (hf_token or HF_API_TOKEN or "").strip()
+     hf_model = (hf_model or HF_MODEL_ID).strip()
+     uploaded_path = _get_uploaded_path(uploaded_file)
+
+     if uploaded_file is not None and uploaded_path is None:
+         return "Could not read the uploaded file. Please try uploading again."
+
+     if uploaded_path is None and not os.path.exists(PDF_PATH):
+         return "Please upload a PDF file."
+
+     try:
+         # 0. Handle the file upload. Recent Gradio versions pass a named temp
+         #    file path (older ones a file object), so copy it into our data directory.
+         if uploaded_path is not None:
+             os.makedirs(os.path.dirname(PDF_PATH), exist_ok=True)
+             shutil.copy(uploaded_path, PDF_PATH)
+             print(f"File saved to {PDF_PATH}")
+
+             # Force re-ingestion since we have a new file.
+             print("Ingesting PDF...")
+             chunks = ingest_file(str(PDF_PATH))
+             vectorstore = create_vectorstore(chunks)
+
+         # 1. Load or create the vector store (if not already created above).
+         if vectorstore is None:
+             vectorstore = load_vectorstore()
+             if vectorstore is None:
+                 # This case should be covered by the upload logic, but just in case.
+                 if os.path.exists(PDF_PATH):
+                     print("Ingesting PDF...")
+                     chunks = ingest_file(str(PDF_PATH))
+                     vectorstore = create_vectorstore(chunks)
+                 else:
+                     return "Source PDF not found. Please upload a file."
+
+         # 2. Create the agent (LangGraph).
+         agent_executor = build_langgraph_agent(vectorstore, hf_api_token=hf_token, hf_model_id=hf_model)
+         current_hf_token = hf_token
+         current_hf_model = hf_model
+         mode = "Hugging Face router" if hf_token else "local transformers (no HF token provided)"
+
+         return f"System Initialized Successfully using {mode}. You can now start chatting."
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return f"Initialization Failed: {str(e)}"
+
+
+ def chat(message, history, hf_token, hf_model, uploaded_file):
+     """
+     Chat function for Gradio.
+     """
+     global agent_executor, current_hf_token, current_hf_model
+
+     # Gradio can pass None for history on the first turn.
+     history = history or []
+     if not message:
+         return "Please enter a message to start chatting."
+
+     hf_token = (hf_token or HF_API_TOKEN or "").strip()
+     hf_model = (hf_model or HF_MODEL_ID).strip()
+
+     # Re-initialize if the token/model changed or the agent is not initialized.
+     if agent_executor is None or hf_token != current_hf_token or hf_model != current_hf_model:
+         init_msg = initialize_system(hf_token, hf_model, uploaded_file)
+         if "Failed" in init_msg or "Please" in init_msg:
+             return init_msg
+
+     # Run the agent. This simple implementation is stateless between calls
+     # (no checkpointer), so we send only the current message; mapping the full
+     # Gradio history to LangChain messages is a possible improvement.
+     try:
+         response = agent_executor.invoke({"messages": [HumanMessage(content=message)]})
+         return response["messages"][-1].content
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         hint = (
+             " If you used the Hugging Face router, verify the token/model. "
+             "Otherwise, try re-initializing to refresh the vector store."
+         )
+         return f"Error while generating a reply: {str(e)}{hint}"
+
+
+ # Gradio UI
+ with gr.Blocks(title="RAG Chatbot (LangGraph + HF)") as demo:
+     gr.Markdown("# RAG-Based Chatbot (LangGraph + Hugging Face)")
+     gr.Markdown(
+         "Upload a PDF, build a vector store, retrieve context, and answer with either the Hugging Face router "
+         "(when a token + router model is provided) or a local fallback model."
+     )
+
+     with gr.Row():
+         api_key_input = gr.Textbox(
+             label="Hugging Face API Token (optional)",
+             type="password",
+             placeholder="hf_...",
+             value=os.getenv("HUGGINGFACEHUB_API_TOKEN", ""),
+         )
+         model_input = gr.Textbox(
+             label="Model ID",
+             placeholder="e.g. meta-llama/Meta-Llama-3-8B-Instruct",
+             value=os.getenv("HF_MODEL_ID", HF_MODEL_ID),
+         )
+     file_input = gr.File(label="Upload PDF", file_types=[".pdf"], type="filepath")
+     init_btn = gr.Button("Initialize System")
+
+     status_output = gr.Textbox(label="Status", interactive=False)
+
+     chatbot = gr.ChatInterface(
+         fn=chat,
+         additional_inputs=[api_key_input, model_input, file_input],
+     )

+     init_btn.click(initialize_system, inputs=[api_key_input, model_input, file_input], outputs=[status_output])

  if __name__ == "__main__":
+     # Use local launch by default; share links can fail without network access.
+     demo.queue().launch(share=False)
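The comments in `chat()` note that Gradio's `[user, bot]` history pairs are not yet mapped to LangChain messages. One possible sketch of that mapping, using plain role-tagged dicts to stay dependency-free (they would be converted to `HumanMessage`/`AIMessage` in the real app; the function name is hypothetical):

```python
def history_to_messages(history):
    """Map Gradio's [user, bot] history pairs to role-tagged messages."""
    messages = []
    for user_turn, bot_turn in history or []:
        messages.append({"role": "user", "content": user_turn})
        if bot_turn:  # the bot reply may be None for the in-flight turn
            messages.append({"role": "assistant", "content": bot_turn})
    return messages
```

Passing the result (plus the new user message) into the agent's `invoke` payload would give the graph conversational context without needing a checkpointer.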
data/README.md ADDED
@@ -0,0 +1,2 @@
+ This directory stores uploaded PDFs and the generated FAISS index at runtime.
+ These files are ignored in version control to keep the repo lightweight for GitHub and Hugging Face Spaces.
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ langchain==0.3.7
+ langchain-community==0.3.7
+ langchain-text-splitters==0.3.2
+ langchain-huggingface==0.1.2
+ langgraph==0.2.39
+ gradio==4.44.1
+ python-dotenv==1.0.1
+ sentence-transformers==2.6.1
+ faiss-cpu==1.7.4
+ pypdf==4.2.0
+ pydantic==2.9.2
+ huggingface-hub==0.23.4
+ transformers>=4.37.0
src/__init__.py ADDED
File without changes
src/agent.py ADDED
@@ -0,0 +1,240 @@
+ from typing import List, Optional, TypedDict
+ from types import SimpleNamespace
+ import requests
+ from langgraph.graph import StateGraph, END
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from .config import HF_MODEL_ID, HF_API_TOKEN, LOCAL_MODEL_ID, TEMPERATURE
+
+ # Cache the local model/pipeline to avoid repeated downloads.
+ _LOCAL_PIPELINE = None
+ _LOCAL_MODEL_ID = None
+
+
+ def _build_prompt(question: str, docs: List) -> str:
+     """Create a concise prompt that uses retrieved context."""
+     context = "\n\n".join(d.page_content for d in docs[:4])
+     return (
+         "You are a helpful assistant. Use the provided context to answer the question. "
+         "If the context is insufficient, say you do not know.\n\n"
+         f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
+     )
+
+
+ class ChatState(TypedDict):
+     messages: List[BaseMessage]
+     context: str
+
+
+ def _hf_generate(prompt: str, model_id: str, token: Optional[str], temperature: float) -> str:
+     """
+     Minimal text-generation call against the Hugging Face router API.
+     """
+     url = f"https://router.huggingface.co/models/{model_id}"
+     headers = {"Accept": "application/json"}
+     if token:
+         headers["Authorization"] = f"Bearer {token}"
+     payload = {
+         "inputs": prompt,
+         "parameters": {
+             "max_new_tokens": 512,
+             "temperature": temperature,
+             "return_full_text": False,
+         },
+     }
+     try:
+         resp = requests.post(url, headers=headers, json=payload, timeout=60)
+         resp.raise_for_status()
+     except requests.HTTPError as http_err:
+         status = http_err.response.status_code if http_err.response is not None else None
+         if status == 404:
+             raise RuntimeError(
+                 f"Model '{model_id}' not found on Hugging Face router. "
+                 f"Set HF_MODEL_ID to a router-available text-generation model and retry."
+             ) from http_err
+         raise
+     except requests.RequestException as req_err:
+         # Network-layer issues (timeouts, DNS, etc.) should surface cleanly so we can fall back.
+         raise RuntimeError(f"Hugging Face router request failed: {req_err}") from req_err
+     data = resp.json()
+     # The HF router can return a list or a dict; handle both.
+     if isinstance(data, list) and data and isinstance(data[0], dict):
+         if "generated_text" in data[0]:
+             return data[0]["generated_text"]
+         if "error" in data[0]:
+             raise RuntimeError(data[0]["error"])
+     if isinstance(data, dict):
+         if "generated_text" in data:
+             return data["generated_text"]
+         if "error" in data:
+             raise RuntimeError(data["error"])
+     return str(data)
+
+
+ def _local_generate(prompt: str, model_id: str, temperature: float) -> str:
+     """
+     Fallback local generation using a transformers pipeline (no HF API token needed).
+     Truncates the prompt to fit within the model's max position embeddings to avoid index errors.
+     """
+     global _LOCAL_PIPELINE, _LOCAL_MODEL_ID
+
+     if _LOCAL_PIPELINE is None or _LOCAL_MODEL_ID != model_id:
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         model = AutoModelForCausalLM.from_pretrained(model_id)
+         _LOCAL_PIPELINE = pipeline(
+             "text-generation",
+             model=model,
+             tokenizer=tokenizer,
+             device_map="cpu",
+         )
+         _LOCAL_MODEL_ID = model_id
+
+     tokenizer = _LOCAL_PIPELINE.tokenizer
+     model = _LOCAL_PIPELINE.model
+     max_new_tokens = 128
+
+     # Determine the max prompt length to prevent IndexError for small context windows (e.g., gpt2 = 1024).
+     max_positions = getattr(getattr(model, "config", None), "max_position_embeddings", None)
+     pad_token_id = tokenizer.eos_token_id or tokenizer.pad_token_id
+     if max_positions and isinstance(max_positions, int):
+         allowed = max_positions - max_new_tokens - 1
+         if allowed > 0:
+             input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+             if len(input_ids) > allowed:
+                 # Keep the tail of the prompt (most recent question + context).
+                 input_ids = input_ids[-allowed:]
+                 prompt = tokenizer.decode(input_ids, skip_special_tokens=True)
+
+     outputs = _LOCAL_PIPELINE(
+         prompt,
+         max_new_tokens=max_new_tokens,
+         do_sample=temperature > 0,
+         temperature=temperature,
+         pad_token_id=pad_token_id,
+     )
+     # The transformers pipeline returns a list of dicts.
+     if outputs and isinstance(outputs[0], dict) and "generated_text" in outputs[0]:
+         return outputs[0]["generated_text"]
+     return str(outputs)
+
+
+ def build_agent(
+     vectorstore,
+     hf_model_id: Optional[str] = None,
+     hf_api_token: Optional[str] = None,
+     temperature: Optional[float] = None,
+ ):
+     """
+     Simple RAG agent using Hugging Face router inference (text generation).
+     """
+     retriever = vectorstore.as_retriever()
+     model_id = (hf_model_id or HF_MODEL_ID).strip()
+     local_model_id = (LOCAL_MODEL_ID or model_id).strip()
+     token = (hf_api_token or HF_API_TOKEN or "").strip() or None
+     temp = TEMPERATURE if temperature is None else temperature
+
+     def invoke(payload):
+         messages = payload.get("messages", [])
+         user_content = messages[-1].content if messages else ""
+
+         # Prefer invoke to avoid deprecation warnings.
+         if hasattr(retriever, "invoke"):
+             docs = retriever.invoke(user_content)
+         else:
+             docs = retriever.get_relevant_documents(user_content)
+         prompt = _build_prompt(user_content, docs)
+         # Use the router if a token is provided; otherwise fall back to local generation.
+         try:
+             if token:
+                 text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
+             else:
+                 text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
+         except Exception as api_err:
+             if token:
+                 # Degrade gracefully to local generation when the router is flaky or the model is blocked.
+                 fallback_note = (
+                     f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
+                 )
+                 print(fallback_note)
+                 text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
+                 text = f"{text}\n\n{fallback_note}"
+             else:
+                 raise
+         return {"messages": [AIMessage(content=text)]}
+
+     # Return an object with an invoke method to mirror the previous agent_executor shape.
+     return SimpleNamespace(invoke=invoke)
+
+
+ def build_langgraph_agent(
+     vectorstore,
+     hf_model_id: Optional[str] = None,
+     hf_api_token: Optional[str] = None,
+     temperature: Optional[float] = None,
+ ):
+     """
+     LangGraph-based RAG agent with retrieval + generation nodes.
+     """
+     retriever = vectorstore.as_retriever()
+     model_id = (hf_model_id or HF_MODEL_ID).strip()
+     local_model_id = (LOCAL_MODEL_ID or model_id).strip()
+     token = (hf_api_token or HF_API_TOKEN or "").strip() or None
+     temp = TEMPERATURE if temperature is None else temperature
+
+     def retrieve_node(state: ChatState):
+         messages = state.get("messages", [])
+         user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
+         query = user_msg.content if user_msg else ""
+
+         if hasattr(retriever, "invoke"):
+             docs = retriever.invoke(query)
+         else:
+             docs = retriever.get_relevant_documents(query)
+         context = "\n\n".join(d.page_content for d in docs[:4])
+         return {"context": context}
+
+     def generate_node(state: ChatState):
+         messages = state.get("messages", [])
+         context = state.get("context", "")
+         user_msg = next((m for m in reversed(messages) if isinstance(m, HumanMessage)), None)
+         question = user_msg.content if user_msg else ""
+
+         prompt = (
+             "You are a helpful assistant. Use the provided context to answer the question. "
+             "If the context is insufficient, say you do not know.\n\n"
+             f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
+         )
+
+         try:
+             if token:
+                 text = _hf_generate(prompt, model_id=model_id, token=token, temperature=temp)
+             else:
+                 text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
+         except Exception as api_err:
+             if token:
+                 fallback_note = (
+                     f"[Fallback to local model '{local_model_id}' because HF router failed: {api_err}]"
+                 )
+                 print(fallback_note)
+                 text = _local_generate(prompt, model_id=local_model_id, temperature=temp)
+                 text = f"{text}\n\n{fallback_note}"
+             else:
+                 raise
+         return {"messages": messages + [AIMessage(content=text)]}
+
+     graph = StateGraph(ChatState)
+     graph.add_node("retrieve", retrieve_node)
+     graph.add_node("generate", generate_node)
+     graph.set_entry_point("retrieve")
+     graph.add_edge("retrieve", "generate")
+     graph.add_edge("generate", END)
+
+     app = graph.compile()
+
+     # Wrap to mirror the previous agent_executor interface for Gradio.
+     def invoke(payload):
+         incoming_messages = payload.get("messages", [])
+         initial_state: ChatState = {"messages": incoming_messages, "context": ""}
+         return app.invoke(initial_state)
+
+     return SimpleNamespace(invoke=invoke)
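`_local_generate` above guards against overflowing small context windows by keeping only the tail of the prompt. The arithmetic can be sketched with a stand-in token list (the real code operates on tokenizer ids; the function name here is illustrative):

```python
def truncate_prompt_tail(token_ids, max_positions, max_new_tokens):
    """Keep the most recent tokens so prompt + generation fits the model window."""
    allowed = max_positions - max_new_tokens - 1
    if allowed <= 0 or len(token_ids) <= allowed:
        return token_ids  # nothing to trim (or window too small to help)
    return token_ids[-allowed:]  # keep the tail: most recent question + context
```

For a gpt2-sized window (1024 positions) with 128 new tokens reserved, at most 895 prompt tokens survive, which is exactly the bound the agent computes.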
src/config.py ADDED
@@ -0,0 +1,27 @@
+ import os
+ from pathlib import Path
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # Base paths
+ BASE_DIR = Path(__file__).resolve().parent.parent
+ DATA_DIR = BASE_DIR / "data"
+ SRC_DIR = BASE_DIR / "src"
+
+ # Data paths
+ PDF_PATH = DATA_DIR / "source.pdf"  # The uploaded input PDF is copied to this path
+ VECTORSTORE_PATH = DATA_DIR / "faiss_index"
+
+ # RAG parameters
+ CHUNK_SIZE = 1000
+ CHUNK_OVERLAP = 200
+ EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+
+ # LLM parameters (Hugging Face free Inference API)
+ # The default router model should exist on the router. Override via the HF_MODEL_ID env var or UI input.
+ # Meta Llama 3 8B Instruct is widely available on the HF router as of Nov 2024.
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "meta-llama/Meta-Llama-3-8B-Instruct")
+ HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")  # Optional for many free endpoints
+ LOCAL_MODEL_ID = os.getenv("LOCAL_MODEL_ID", "distilgpt2")
+ TEMPERATURE = float(os.getenv("HF_TEMPERATURE", "0.3"))
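The env-driven defaults above follow a simple resolve-with-fallback pattern. One defensive variant worth noting: the `TEMPERATURE` line as written raises `ValueError` on a malformed `HF_TEMPERATURE`, so a sketch with a safe float parser (the helper name is hypothetical, not part of the module):

```python
import os

def env_float(name, default):
    """Read a float from the environment, falling back on missing or bad values."""
    raw = os.getenv(name)
    try:
        return float(raw) if raw is not None else default
    except ValueError:
        return default  # e.g. HF_TEMPERATURE="warm" silently uses the default
```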
src/ingestion.py ADDED
@@ -0,0 +1,33 @@
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from .config import CHUNK_SIZE, CHUNK_OVERLAP
+
+ def load_pdf(file_path):
+     """
+     Loads a PDF file and returns a list of documents.
+     """
+     loader = PyPDFLoader(file_path)
+     documents = loader.load()
+     return documents
+
+ def chunk_documents(documents):
+     """
+     Splits documents into smaller chunks.
+     """
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=CHUNK_SIZE,
+         chunk_overlap=CHUNK_OVERLAP,
+         length_function=len,
+         is_separator_regex=False,
+     )
+     chunks = text_splitter.split_documents(documents)
+     return chunks
+
+ def ingest_file(file_path):
+     """
+     Orchestrates loading and chunking.
+     """
+     docs = load_pdf(file_path)
+     chunks = chunk_documents(docs)
+     print(f"Loaded {len(docs)} pages and created {len(chunks)} chunks.")
+     return chunks
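`RecursiveCharacterTextSplitter` does smarter, separator-aware splitting, but the effect of `CHUNK_SIZE` and `CHUNK_OVERLAP` can be illustrated with a naive fixed-window sketch (not the actual splitter; defaults mirror the 1000/200 config values):

```python
def naive_chunk(text, size=1000, overlap=200):
    """Fixed-size windows where each chunk shares `overlap` chars with the previous one."""
    step = size - overlap  # advance by size minus overlap each window
    return [text[i:i + size] for i in range(0, max(len(text) - overlap, 1), step)]
```

The overlap means a sentence cut at a chunk boundary still appears whole in the neighboring chunk, which is why retrieval rarely loses boundary-spanning facts.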
src/rag_tool.py ADDED
@@ -0,0 +1,20 @@
+ from langchain_core.tools import tool
+
+ def get_retriever_tool(vectorstore):
+     """
+     Creates a LangChain tool from the vector store retriever.
+     """
+     retriever = vectorstore.as_retriever()
+
+     @tool
+     def retrieve_rag_docs(query: str) -> str:
+         """Search and retrieve information about the RAG Chatbot and LangGraph Agent project from the knowledge base."""
+         # Use invoke if available, else get_relevant_documents.
+         if hasattr(retriever, "invoke"):
+             docs = retriever.invoke(query)
+         else:
+             docs = retriever.get_relevant_documents(query)
+
+         return "\n\n".join([d.page_content for d in docs])
+
+     return retrieve_rag_docs
src/vectorstore.py ADDED
@@ -0,0 +1,33 @@
+ import os
+ from langchain_community.vectorstores import FAISS
+ try:
+     # Preferred newer package
+     from langchain_huggingface import HuggingFaceEmbeddings
+ except ImportError:
+     # Fall back to the older location if the extra package is missing
+     from langchain_community.embeddings import HuggingFaceEmbeddings
+ from .config import EMBEDDING_MODEL_NAME, VECTORSTORE_PATH
+
+ def get_embeddings():
+     """
+     Initializes the embedding model.
+     """
+     return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+
+ def create_vectorstore(chunks):
+     """
+     Creates a FAISS vector store from chunks and saves it locally.
+     """
+     embeddings = get_embeddings()
+     vectorstore = FAISS.from_documents(chunks, embeddings)
+     vectorstore.save_local(str(VECTORSTORE_PATH))
+     return vectorstore
+
+ def load_vectorstore():
+     """
+     Loads the FAISS vector store from disk.
+     """
+     embeddings = get_embeddings()
+     if os.path.exists(VECTORSTORE_PATH):
+         return FAISS.load_local(str(VECTORSTORE_PATH), embeddings, allow_dangerous_deserialization=True)
+     return None
tests/test_pipeline.py ADDED
@@ -0,0 +1,48 @@
+ import sys
+ import os
+ from pathlib import Path
+
+ # Add the project root to sys.path
+ sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+ from src.ingestion import load_pdf, chunk_documents
+ from src.vectorstore import create_vectorstore, load_vectorstore
+ from src.config import PDF_PATH
+
+ def test_ingestion():
+     print("Testing Ingestion...")
+     if not os.path.exists(PDF_PATH):
+         print(f"Skipping ingestion test: {PDF_PATH} not found.")
+         return
+
+     docs = load_pdf(str(PDF_PATH))
+     assert len(docs) > 0, "No documents loaded"
+     print(f"Loaded {len(docs)} pages.")
+
+     chunks = chunk_documents(docs)
+     assert len(chunks) > 0, "No chunks created"
+     print(f"Created {len(chunks)} chunks.")
+     return chunks
+
+ def test_vectorstore(chunks):
+     print("Testing Vector Store...")
+     if not chunks:
+         print("Skipping vector store test: No chunks.")
+         return
+
+     vs = create_vectorstore(chunks)
+     assert vs is not None, "Vector store creation failed"
+     print("Vector store created and saved.")
+
+     loaded_vs = load_vectorstore()
+     assert loaded_vs is not None, "Vector store loading failed"
+     print("Vector store loaded successfully.")
+
+ if __name__ == "__main__":
+     try:
+         chunks = test_ingestion()
+         test_vectorstore(chunks)
+         print("All tests passed!")
+     except Exception as e:
+         print(f"Test failed: {e}")
+         sys.exit(1)