Darayut committed on
Commit 2f923a7 · verified · 1 Parent(s): 711b3e6

Update src/simple_rag.py

Files changed (1)
  1. src/simple_rag.py +49 -60
src/simple_rag.py CHANGED
@@ -1,30 +1,43 @@
-# Modified RAG Pipeline for General Document Q&A (Khmer & English)
-
 import os
 import logging
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModel
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import Document
-# Updated imports for LangChain
-from langchain_community.vectorstores import Chroma
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from langchain_community.document_loaders import PyPDFDirectoryLoader
+from langchain.vectorstores.chroma import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.document_loaders import PyPDFDirectoryLoader
+from transformers import BitsAndBytesConfig

 logging.basicConfig(level=logging.INFO)

 use_gpu = torch.cuda.is_available()
+
+if use_gpu:
+    print("CUDA device in use:", torch.cuda.get_device_name(0))
+else:
+    print("Running on CPU. No GPU detected.")
+
 model_id = "aisingapore/Llama-SEA-LION-v3.5-8B-R"

-logging.info(use_gpu)

 # # Load model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    load_in_8bit=True,  # Quantization
-    device_map="cpu",  # Force CPU
-)
+
+if use_gpu:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        device_map="auto",
+        load_in_8bit=True,
+        torch_dtype=torch.float16,
+    )
+else:
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        load_in_8bit=True,
+        device_map={"": "cpu"},  # Force CPU
+    )
+

 pipeline = pipeline(
     "text-generation",
@@ -32,29 +45,25 @@ pipeline = pipeline(
     tokenizer=tokenizer,
 )

-# Use Hugging Face's writable directory
-WRITABLE_DIR = os.environ.get("HOME", "/app")
-
-DATA_PATH = os.path.join(WRITABLE_DIR, "src", "data")
-CHROMA_PATH = os.path.join(WRITABLE_DIR, "src", "chroma")
-
+DATA_PATH = "./data/"
+CHROMA_PATH = "chroma"
 embedding_model = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-base")

-# PROMPT_TEMPLATE = """
-# You are a helpful assistant.
-# Answer the question based ONLY on the context below.
-# If the user asks in Khmer, respond in Khmer.
-# If the user asks in English, respond in English.
-# Use clear, concise sentences, no more than 50 word. Do not mention the existence of context.
-
-# Context:
-# {context}
-
-# Question:
-# {question}
-
-# Answer:
-# """
+# Generic assistant prompt for dual Khmer/English
+PROMPT_TEMPLATE = """
+You are a helpful assistant.
+Answer the question based ONLY on the context below.
+
+Use clear, concise sentences, no more than 50 words. Do not mention the existence of context.
+
+Context:
+{context}
+
+Question:
+{question}
+
+Answer:
+""".strip()

 def load_documents():
     loader = PyPDFDirectoryLoader(DATA_PATH)
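The path constants introduced here (`./data/` and `chroma`) resolve against the process working directory, unlike the HOME-based layout they replace, so they are worth re-checking on hosted Spaces where the working directory may not be writable. A separate note on the embedder: the `intfloat/multilingual-e5-base` family is trained with "query: " and "passage: " prefixes, and retrieval quality typically degrades when they are omitted. A hypothetical helper, not part of this commit, that would add them:

# Sketch only: hypothetical helper, not in the committed file. E5 models
# expect "passage: " on indexed chunks and "query: " on search queries.
def add_e5_prefix(texts: list[str], kind: str = "passage") -> list[str]:
    return [f"{kind}: {t}" for t in texts]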
@@ -62,7 +71,7 @@ def load_documents():

 def split_text(documents: list[Document]):
     splitter = RecursiveCharacterTextSplitter(
-        chunk_size=256, chunk_overlap=50, length_function=len, add_start_index=True
+        chunk_size=512, chunk_overlap=100, length_function=len, add_start_index=True
     )
     chunks = splitter.split_documents(documents)
     logging.info(f"Split {len(documents)} documents into {len(chunks)} chunks.")
@@ -102,44 +111,24 @@ def ask_question(query_text: str, k: int = 3):
         })

     context_text = "\n\n".join(chunk["text"] for chunk in context_chunks)
-    #prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)
-    #logging.info(f"Prompt: {prompt}")
-
-    # Construct structured messages instead of using PROMPT_TEMPLATE
-    messages = [
-        {
-            "role": "user",
-            "content": f"""Base your answer only on the following context:\n\n{context_text}\n\nQuestion: {query_text}\nAnswer:"""
-        }
-    ]
+    prompt = PROMPT_TEMPLATE.format(context=context_text, question=query_text)

+    messages = [{"role": "user", "content": prompt}]
+    logging.info("Sending prompt to model...")
     prompt = tokenizer.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    tokenize=False,
-    thinking_mode="off"
-)
-
-    logging.info(f"Prompts: {prompt}")
+        messages,
+        add_generation_prompt=True,
+        tokenize=False,
+        thinking_mode="off"
+    )

     output = pipeline(
         prompt,
         max_new_tokens=128,
-        do_sample=False,
         return_full_text=False,
         truncation=True,
+        do_sample=False,
     )

-    # output = pipeline(
-    #     messages,
-    #     max_new_tokens=256,
-    #     return_full_text=False,
-    #     truncation=True,
-    #     do_sample=False,
-    # )
-
-
-    logging.info(f"Output: {output}")
-
     answer = output[0]["generated_text"].strip()
     return answer, context_chunks
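One wrinkle the commit leaves in place: `pipeline = pipeline(...)` rebinds the imported `transformers.pipeline` factory to its own result, which works once but prevents building a second pipeline and makes the code harder to read. A sketch of the same calls under a distinct name; the `model=` argument is assumed, since that context line falls outside the hunk:

# Sketch only: identical behavior without shadowing transformers.pipeline.
generator = pipeline(
    "text-generation",
    model=model,      # assumed from context not shown in the hunk
    tokenizer=tokenizer,
)
output = generator(
    prompt,
    max_new_tokens=128,
    return_full_text=False,
    truncation=True,
    do_sample=False,
)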
 
 
 
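For reference, an end-to-end call of the updated function, assuming the Chroma index has already been built from the PDFs under DATA_PATH (the query string is illustrative):

# Sketch only: example invocation of ask_question as defined in this file.
answer, chunks = ask_question("What does the document say about registration?", k=3)
print(answer)
for chunk in chunks:
    print(chunk["text"][:80])  # preview each retrieved context chunk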