Darayut committed
Commit 3298436 · verified · 1 Parent(s): 9ff648d

Upload simple_rag.py

Files changed (1)
  1. src/simple_rag.py +121 -0
src/simple_rag.py ADDED
@@ -0,0 +1,121 @@
+ # Modified RAG Pipeline for General Document Q&A (Khmer & English)
+
+ import os
+ import logging
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from langchain.vectorstores.chroma import Chroma
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.document_loaders import PyPDFDirectoryLoader
+
+ logging.basicConfig(level=logging.INFO)
+
+ use_gpu = torch.cuda.is_available()
+ model_id = "aisingapore/Llama-SEA-LION-v3.5-8B-R"
+
+
+ # Load model and tokenizer. bitsandbytes 8-bit quantization needs a CUDA device,
+ # so quantize only when a GPU is available and fall back to a plain CPU load otherwise.
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ if use_gpu:
+     quant_config = BitsAndBytesConfig(
+         load_in_8bit=True,
+         llm_int8_enable_fp32_cpu_offload=True,  # offload layers that do not fit in VRAM
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         quantization_config=quant_config,
+         device_map="auto",
+     )
+ else:
+     model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": "cpu"})
+
+ pipe = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+ )
+
+ DATA_PATH = "./data/"
+ CHROMA_PATH = "chroma"
+ embedding_model = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-base")
+
+ # Generic assistant prompt for dual Khmer/English
+ PROMPT_TEMPLATE = """
+ You are a helpful assistant.
+ Answer the question based ONLY on the context below.
+ If the user asks in Khmer, respond in Khmer.
+ If the user asks in English, respond in English.
+ Use clear, concise sentences. Do not mention the existence of context.
+
+ Context:
+ {context}
+
+ Question:
+ {question}
+
+ Answer:
+ """.strip()
+
+ def load_documents():
+     loader = PyPDFDirectoryLoader(DATA_PATH)
+     return loader.load()
+
+ def split_text(documents: list[Document]):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=512, chunk_overlap=100, length_function=len, add_start_index=True
+     )
+     chunks = splitter.split_documents(documents)
+     logging.info(f"Split {len(documents)} documents into {len(chunks)} chunks.")
+     return chunks
+
+ def save_to_chroma(chunks: list[Document]):
+     if os.path.exists(CHROMA_PATH):
+         db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_model)
+         db.add_documents(chunks)
+         logging.info("Added documents to existing Chroma DB.")
+     else:
+         db = Chroma.from_documents(
+             chunks, embedding_model, persist_directory=CHROMA_PATH
+         )
+         logging.info("Created new Chroma DB.")
+     db.persist()
+     logging.info(f"Saved {len(chunks)} chunks to Chroma.")
+
+ def generate_data_store():
+     documents = load_documents()
+     chunks = split_text(documents)
+     save_to_chroma(chunks)
+
+ def ask_question(query_text: str, k: int = 3):
+     logging.info("Processing user question...")
+
+     db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_model)
+     results = db.similarity_search(query_text, k=k)
+
+     # Collect the retrieved chunks together with their source metadata.
+     context_chunks = []
+     for doc in results:
+         meta = doc.metadata or {}
+         context_chunks.append({
+             "filename": os.path.basename(meta.get("source", "unknown.pdf")),
+             "page": meta.get("page", 1),
+             "text": doc.page_content.strip()
+         })
+
+     context_text = "\n\n".join(chunk["text"] for chunk in context_chunks)
+     prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)
+
+     messages = [{"role": "user", "content": prompt}]
+     logging.info("Sending prompt to model...")
+     prompt = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=False,
+         thinking_mode="off"  # forwarded to the chat template to suppress the reasoning trace
+     )
+
+     # Use the pipeline object created above (`pipe`), not the `pipeline` factory function.
+     output = pipe(
+         prompt,
+         max_new_tokens=1024,
+         return_full_text=False,
+         truncation=True,
+         do_sample=False,
+     )
+
+     answer = output[0]["generated_text"].strip()
+     return answer, context_chunks
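
For reference, a minimal sketch of how the two entry points in this file might be wired together. It is illustrative only and not part of the uploaded file; the ./data/ PDF directory is the one the module already expects, and the example question is hypothetical.

# Illustrative usage sketch (not part of the commit).
if __name__ == "__main__":
    generate_data_store()  # index every PDF under ./data/ into the Chroma store

    # Hypothetical example query; a Khmer question would be handled the same way.
    answer, sources = ask_question("What topics does the document cover?", k=3)
    print(answer)
    for chunk in sources:
        print(f"- {chunk['filename']} (page {chunk['page']})")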