kiranmadhusud committed on
Commit
6237214
·
1 Parent(s): 540c55e

fix RAG app

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +42 -27
rag_pipeline.py CHANGED
@@ -1,13 +1,13 @@
1
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
 
 
2
  from sentence_transformers import SentenceTransformer
3
- from langchain_text_splitters import RecursiveCharacterTextSplitter
4
  import faiss
5
  import numpy as np
6
  import torch
7
 
8
- # ── Models (all free, no API key needed) ──────────────────────────────────────
9
- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # lightweight embedder
10
- LLM_MODEL = "google/flan-t5-base" # swap for flan-t5-large if on GPU
11
 
12
  class RAGPipeline:
13
  def __init__(self):
@@ -15,35 +15,26 @@ class RAGPipeline:
15
  self.embedder = SentenceTransformer(EMBEDDING_MODEL)
16
 
17
  print("Loading LLM...")
18
- self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
19
- self.model = AutoModelForSeq2SeqLM.from_pretrained(LLM_MODEL)
20
  self.llm = pipeline(
21
- "text2text-generation",
22
- model=self.model,
23
- tokenizer=self.tokenizer,
24
- max_new_tokens=300,
25
- device=0 if torch.cuda.is_available() else -1,
26
  )
27
 
28
  self.index = None
29
  self.chunks = []
30
- self.splitter = RecursiveCharacterTextSplitter(
31
- chunk_size=500, chunk_overlap=50
32
- )
33
 
34
  def ingest(self, text: str):
35
- """Chunk text, embed it, store in FAISS."""
36
- self.chunks = self.splitter.split_text(text)
37
  embeddings = self.embedder.encode(self.chunks, show_progress_bar=False)
38
  embeddings = np.array(embeddings).astype("float32")
39
-
40
  dim = embeddings.shape[1]
41
  self.index = faiss.IndexFlatL2(dim)
42
  self.index.add(embeddings)
43
  return len(self.chunks)
44
 
45
  def retrieve(self, query: str, top_k: int = 3):
46
- """Find the most relevant chunks for a query."""
47
  if self.index is None:
48
  return []
49
  q_emb = self.embedder.encode([query]).astype("float32")
@@ -51,17 +42,41 @@ class RAGPipeline:
51
  return [self.chunks[i] for i in indices[0] if i < len(self.chunks)]
52
 
53
  def answer(self, query: str):
54
- """Full RAG: retrieve β†’ build prompt β†’ generate."""
55
  context_chunks = self.retrieve(query)
56
  if not context_chunks:
57
  return "⚠️ Please upload a document first."
58
 
59
  context = "\n\n".join(context_chunks)
60
- prompt = (
61
- f"Answer the question based only on the context below.\n\n"
62
- f"Context:\n{context}\n\n"
63
- f"Question: {query}\n\n"
64
- f"Answer:"
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
- result = self.llm(prompt)[0]["generated_text"]
67
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # rag_pipeline.py
2
+
3
+ from transformers import pipeline
4
  from sentence_transformers import SentenceTransformer
 
5
  import faiss
6
  import numpy as np
7
  import torch
8
 
9
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
10
+ LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # βœ… works with text-generation
 
11
 
12
  class RAGPipeline:
13
  def __init__(self):
 
15
  self.embedder = SentenceTransformer(EMBEDDING_MODEL)
16
 
17
  print("Loading LLM...")
 
 
18
  self.llm = pipeline(
19
+ "text-generation", # βœ… use this instead
20
+ model=LLM_MODEL,
21
+ torch_dtype=torch.float32,
22
+ device_map="auto" if torch.cuda.is_available() else None,
 
23
  )
24
 
25
  self.index = None
26
  self.chunks = []
 
 
 
27
 
28
def ingest(self, text: str):
    """Chunk *text*, embed the chunks, and build a FAISS L2 index.

    Args:
        text: Raw document text to index.

    Returns:
        The number of chunks stored (0 when the text yields no chunks).
    """
    self.chunks = split_text(text)
    # Guard: encoding an empty chunk list yields an array with no second
    # dimension, so embeddings.shape[1] below would raise IndexError.
    if not self.chunks:
        self.index = None
        return 0

    embeddings = self.embedder.encode(self.chunks, show_progress_bar=False)
    # FAISS requires contiguous float32 input.
    embeddings = np.array(embeddings).astype("float32")

    dim = embeddings.shape[1]
    self.index = faiss.IndexFlatL2(dim)
    self.index.add(embeddings)
    return len(self.chunks)
36
 
37
  def retrieve(self, query: str, top_k: int = 3):
 
38
  if self.index is None:
39
  return []
40
  q_emb = self.embedder.encode([query]).astype("float32")
 
42
  return [self.chunks[i] for i in indices[0] if i < len(self.chunks)]
43
 
44
  def answer(self, query: str):
 
45
  context_chunks = self.retrieve(query)
46
  if not context_chunks:
47
  return "⚠️ Please upload a document first."
48
 
49
  context = "\n\n".join(context_chunks)
50
+
51
+ # TinyLlama uses ChatML format
52
+ prompt = f"""<|system|>
53
+ You are a helpful assistant. Answer only based on the context provided.</s>
54
+ <|user|>
55
+ Context:
56
+ {context}
57
+
58
+ Question: {query}</s>
59
+ <|assistant|>"""
60
+
61
+ result = self.llm(
62
+ prompt,
63
+ max_new_tokens=300,
64
+ do_sample=False,
65
+ temperature=1.0,
66
+ pad_token_id=self.llm.tokenizer.eos_token_id,
67
  )
68
+
69
+ # Strip the prompt β€” return only the generated part
70
+ generated = result[0]["generated_text"]
71
+ answer = generated.split("<|assistant|>")[-1].strip()
72
+ return answer
73
+
74
+
75
def split_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list:
    """Split *text* into overlapping fixed-size character chunks.

    Args:
        text: The document text to split.
        chunk_size: Maximum characters per chunk; must be positive.
        overlap: Characters shared between consecutive chunks; must be
            smaller than chunk_size.

    Returns:
        A list of stripped, non-empty chunk strings (empty list for
        empty/whitespace-only input).

    Raises:
        ValueError: If chunk_size <= 0 or overlap >= chunk_size — with
            those values the scan step (chunk_size - overlap) is <= 0
            and the original loop never terminated.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")

    step = chunk_size - overlap
    raw = (text[start:start + chunk_size] for start in range(0, len(text), step))
    return [chunk.strip() for chunk in raw if chunk.strip()]