Omkar1872 commited on
Commit
5ea49e4
·
verified ·
1 Parent(s): 4c58750

Create rag_qa.py

Browse files
Files changed (1) hide show
  1. rag_qa.py +48 -0
rag_qa.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import faiss
5
+ import numpy as np
6
+
7
+ # Load embedding model
8
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
9
+
10
+ # Load language model
11
+ model_name = "mistralai/Mistral-7B-Instruct-v0.1"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ model_name,
15
+ torch_dtype=torch.float16,
16
+ device_map="auto"
17
+ )
18
+
19
+ # Optional: Add system instruction
20
+ SYSTEM_PROMPT = "You are an AI assistant helping users understand documents."
21
+
22
+ # Load FAISS index and documents
23
+ def load_faiss_index():
24
+ index = faiss.read_index("vector_index.faiss")
25
+ with open("documents.npy", "rb") as f:
26
+ documents = np.load(f, allow_pickle=True)
27
+ return index, documents
28
+
29
+ # Embed the user query
30
+ def embed_query(query):
31
+ return embedding_model.encode([query])[0]
32
+
33
+ # Retrieve top-k relevant documents
34
+ def retrieve_top_k_docs(query_embedding, index, documents, k=3):
35
+ query_embedding = np.array([query_embedding]).astype("float32")
36
+ scores, indices = index.search(query_embedding, k)
37
+ retrieved_docs = [documents[i] for i in indices[0]]
38
+ return retrieved_docs
39
+
40
+ # Generate the final answer
41
+ def generate_answer(context_docs, user_query):
42
+ context = "\n".join(context_docs)
43
+ prompt = f"<s>[INST] {SYSTEM_PROMPT}\n\nContext:\n{context}\n\nQuestion: {user_query} [/INST]"
44
+
45
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
46
+ output = model.generate(**inputs, max_new_tokens=500, do_sample=True)
47
+ answer = tokenizer.decode(output[0], skip_special_tokens=True)
48
+ return answer.split("[/INST]")[-1].strip()