Zubaish committed on
Commit
a42513a
·
1 Parent(s): 81345e2

Frontend: robust answer + status handling

Browse files
Files changed (3) hide show
  1. app.py +8 -9
  2. frontend/index.html +21 -67
  3. rag.py +50 -75
app.py CHANGED
@@ -1,22 +1,21 @@
 
1
  from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
- from fastapi.staticfiles import StaticFiles
4
  from pydantic import BaseModel
5
  from rag import ask_rag_with_status
6
 
7
  app = FastAPI()
8
 
9
- app.mount("/frontend", StaticFiles(directory="frontend"), name="frontend")
10
-
11
  class Query(BaseModel):
12
  question: str
13
 
14
- @app.get("/", response_class=HTMLResponse)
15
- def home():
16
- with open("frontend/index.html", "r", encoding="utf-8") as f:
17
- return f.read()
18
 
19
  @app.post("/chat")
20
  def chat(q: Query):
21
  answer, status = ask_rag_with_status(q.question)
22
- return {"answer": answer, "status": status}
 
 
 
 
1
+ # app.py
2
  from fastapi import FastAPI
 
 
3
  from pydantic import BaseModel
4
  from rag import ask_rag_with_status
5
 
6
  app = FastAPI()
7
 
 
 
8
  class Query(BaseModel):
9
  question: str
10
 
11
+ @app.get("/")
12
+ def health():
13
+ return {"status": "ok"}
 
14
 
15
  @app.post("/chat")
16
  def chat(q: Query):
17
  answer, status = ask_rag_with_status(q.question)
18
+ return {
19
+ "answer": answer,
20
+ "status": status,
21
+ }
frontend/index.html CHANGED
@@ -1,67 +1,21 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset="UTF-8" />
5
- <title>HubRAG</title>
6
- <style>
7
- body {
8
- font-family: sans-serif;
9
- max-width: 800px;
10
- margin: 40px auto;
11
- }
12
- textarea {
13
- width: 100%;
14
- padding: 10px;
15
- }
16
- button {
17
- margin-top: 10px;
18
- padding: 8px 16px;
19
- }
20
- pre {
21
- background: #f5f5f5;
22
- padding: 10px;
23
- white-space: pre-wrap;
24
- }
25
- </style>
26
- </head>
27
- <body>
28
-
29
- <h2>📄 HubRAG (HF Space)</h2>
30
-
31
- <textarea id="q" rows="4" placeholder="Ask a question about the documents..."></textarea>
32
- <br/>
33
- <button onclick="ask()">Ask</button>
34
-
35
- <h3>Status</h3>
36
- <ul id="status"></ul>
37
-
38
- <h3>Answer</h3>
39
- <pre id="answer"></pre>
40
-
41
- <script>
42
- async function ask() {
43
- const q = document.getElementById("q").value;
44
- document.getElementById("answer").textContent = "Thinking...";
45
- document.getElementById("status").innerHTML = "";
46
-
47
- const res = await fetch("/ask", { // <-- ensure this matches backend
48
- method: "POST",
49
- headers: { "Content-Type": "application/json" },
50
- body: JSON.stringify({ question: q })
51
- });
52
-
53
- const data = await res.json();
54
-
55
- document.getElementById("answer").textContent =
56
- data.answer || "No answer";
57
-
58
- (data.status || []).forEach(s => {
59
- const li = document.createElement("li");
60
- li.textContent = s;
61
- document.getElementById("status").appendChild(li);
62
- });
63
- }
64
- </script>
65
-
66
- </body>
67
- </html>
 
# app.py
"""FastAPI backend: health probe plus a /chat endpoint backed by the RAG pipeline."""
from fastapi import FastAPI
from pydantic import BaseModel

from rag import ask_rag_with_status

app = FastAPI()


class Query(BaseModel):
    # Request body for POST /chat: the user's free-text question.
    question: str


@app.get("/")
def health():
    """Liveness check so the hosting platform can confirm the service is up."""
    return {"status": "ok"}


@app.post("/chat")
def chat(q: Query):
    """Run the RAG pipeline on the question and return answer plus status notes."""
    answer, status = ask_rag_with_status(q.question)
    return {"answer": answer, "status": status}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag.py CHANGED
@@ -1,64 +1,33 @@
1
- import os
2
- from typing import List
3
-
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
- from langchain_community.document_loaders import PyPDFLoader
6
- from langchain_text_splitters import RecursiveCharacterTextSplitter
7
  from langchain_community.vectorstores import Chroma
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
-
10
  from config import (
11
- KB_DIR,
12
- VECTOR_DB_DIR,
13
  EMBEDDING_MODEL,
14
  LLM_MODEL,
 
 
15
  )
 
16
 
17
- # --------------------------------------------------
18
- # Embeddings (CPU-safe)
19
- # --------------------------------------------------
20
  embeddings = HuggingFaceEmbeddings(
21
  model_name=EMBEDDING_MODEL
22
  )
23
 
24
- # --------------------------------------------------
25
- # Load PDFs (if any)
26
- # --------------------------------------------------
27
- documents = []
28
-
29
- if os.path.exists(KB_DIR):
30
- for file in os.listdir(KB_DIR):
31
- if file.lower().endswith(".pdf"):
32
- loader = PyPDFLoader(os.path.join(KB_DIR, file))
33
- documents.extend(loader.load())
34
-
35
- # --------------------------------------------------
36
- # Split documents
37
- # --------------------------------------------------
38
- splitter = RecursiveCharacterTextSplitter(
39
- chunk_size=500,
40
- chunk_overlap=50
41
- )
42
-
43
- splits = splitter.split_documents(documents) if documents else []
44
-
45
- # --------------------------------------------------
46
- # Vector DB (ONLY if docs exist)
47
- # --------------------------------------------------
48
- vectordb = None
49
- retriever = None
50
-
51
- if splits:
52
- vectordb = Chroma.from_documents(
53
- splits,
54
- embedding=embeddings,
55
- persist_directory=VECTOR_DB_DIR
56
  )
57
- retriever = vectordb.as_retriever(search_kwargs={"k": 3})
 
58
 
59
- # --------------------------------------------------
60
- # Load LLM (CPU ONLY, NO ACCELERATE)
61
- # --------------------------------------------------
62
  tokenizer = AutoTokenizer.from_pretrained(
63
  LLM_MODEL,
64
  trust_remote_code=True
@@ -66,36 +35,35 @@ tokenizer = AutoTokenizer.from_pretrained(
66
 
67
  model = AutoModelForCausalLM.from_pretrained(
68
  LLM_MODEL,
69
- trust_remote_code=True
 
 
70
  )
71
 
72
- llm = pipeline(
73
- "text-generation",
74
- model=model,
75
- tokenizer=tokenizer,
76
- max_new_tokens=256,
77
- do_sample=False
78
- )
79
 
80
- # --------------------------------------------------
81
- # Public RAG API
82
- # --------------------------------------------------
83
- def ask_rag_with_status(question: str):
84
  status = []
85
 
86
- if retriever is None:
87
- return {
88
- "answer": "❌ Knowledge base is empty. Please upload PDFs to the dataset or storage.",
89
- "status": ["⚠️ No documents indexed"]
90
- }
91
 
92
- status.append("🔍 Retrieving documents...")
93
- docs = retriever.get_relevant_documents(question)
 
 
 
 
 
94
 
95
  context = "\n\n".join(d.page_content for d in docs)
 
96
 
97
  prompt = f"""
98
- Use the following context to answer the question.
 
99
 
100
  Context:
101
  {context}
@@ -103,13 +71,20 @@ Context:
103
  Question:
104
  {question}
105
 
106
- Answer clearly and concisely.
107
  """
108
 
109
- status.append("🧠 Generating answer...")
110
- result = llm(prompt)[0]["generated_text"]
 
 
 
 
 
 
 
 
 
 
111
 
112
- return {
113
- "answer": result.strip(),
114
- "status": status
115
- }
 
# rag.py
from typing import List, Tuple

from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import (
    EMBEDDING_MODEL,
    LLM_MODEL,
    CHROMA_DIR,
    TOP_K,
)
import torch


# --- Embeddings ---
# Sentence-embedding model used both for indexing and for query-time search.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# --- Vector DB (safe load) ---
# Open the persisted Chroma store if it exists; any failure (missing dir,
# schema mismatch, ...) leaves vectordb as None so the API can degrade
# gracefully instead of crashing at import time.
try:
    vectordb = Chroma(
        persist_directory=CHROMA_DIR,
        embedding_function=embeddings,
    )
except Exception:
    vectordb = None


# --- LLM ---
tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL, trust_remote_code=True)

# Half precision on GPU, full precision on CPU; device_map="auto" lets
# accelerate place the weights. NOTE(review): assumes `accelerate` is
# installed for device_map support — confirm in requirements.
model = AutoModelForCausalLM.from_pretrained(
    LLM_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
def ask_rag_with_status(question: str) -> Tuple[str, List[str]]:
    """Answer *question* using the Chroma knowledge base.

    Retrieves the TOP_K most similar chunks, builds a context-grounded
    prompt, and generates an answer with the local causal LM.

    Args:
        question: the user's free-text question.

    Returns:
        (answer, status): the generated answer text and a list of
        human-readable progress messages. Early-exit cases (no vector DB,
        no matching documents) return an explanatory answer instead of
        raising.
    """
    status: List[str] = []

    # The vector store may have failed to load at import time (see module
    # setup) — degrade gracefully rather than crash the request.
    if not vectordb:
        return (
            "⚠️ Knowledge base is not loaded yet. Upload documents first.",
            ["Vector DB not initialized"],
        )

    docs = vectordb.similarity_search(question, k=TOP_K)

    if not docs:
        return (
            "⚠️ I could not find relevant information in the knowledge base.",
            ["No documents retrieved"],
        )

    context = "\n\n".join(d.page_content for d in docs)
    status.append(f"Retrieved {len(docs)} chunks")

    prompt = f"""
You are a helpful assistant.
Answer ONLY using the context below.

Context:
{context}

Question:
{question}

Answer:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens — used below to decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
        )

    # BUGFIX: decode ONLY the newly generated tokens. The previous code
    # decoded the whole sequence and did answer.split("Answer:")[-1], which
    # truncates or leaks prompt text whenever the retrieved context itself
    # contains the string "Answer:".
    answer = tokenizer.decode(
        output[0][prompt_len:], skip_special_tokens=True
    ).strip()
    status.append("Generated answer")

    return answer, status