lonardonifabio committed on
Commit
aa7e321
·
1 Parent(s): c4aa941

Upload 4 files

Browse files
Files changed (4) hide show
  1. llm/__init__.py +0 -0
  2. llm/llm.py +10 -0
  3. llm/prompts.py +14 -0
  4. llm/wrapper.py +38 -0
llm/__init__.py ADDED
File without changes
llm/llm.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.llms import CTransformers
2
+
3
def setup_llm(
    model="models/mistral-7b-instruct-v0.1.Q8_0.gguf",
    model_type="mistral",
    config=None,
):
    """Instantiate the local CTransformers LLM.

    Args:
        model: Path to the GGUF model weights file.
        model_type: Model architecture name understood by ctransformers.
        config: Optional dict of generation settings. Defaults to a
            deterministic setup (temperature 0) with a 4096-token context
            window and up to 2048 new tokens.

    Returns:
        A configured CTransformers LLM instance.
    """
    # Use a None sentinel instead of a dict literal default so the
    # config dict is not shared (and mutable) across calls.
    if config is None:
        config = {"max_new_tokens": 2048, "context_length": 4096, "temperature": 0}

    llm = CTransformers(
        model=model,
        model_type=model_type,
        config=config,
    )
    return llm
llm/prompts.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Note: Precise formatting of spacing and indentation of the prompt template is important,
# as it is highly sensitive to whitespace changes. For example, it could have problems generating
# a summary from the pieces of context if the spacing is not done correctly

# Prompt used by the retrieval QA chain for invoice-field extraction.
# {context} is filled with the retrieved document chunks and {question}
# with the user's query (both bound via PromptTemplate in llm/wrapper.py).
qa_template = """Your role is financial controller.
You are working on invoice documents.
Your main work is to extract data from the invoice.
I would like to extract the following data from the invoices: date, number, sender, final amount and short description of what was purchased.

Context: {context}
Question: {question}

Helpful answer:
"""
llm/wrapper.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import box
2
+ import yaml
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain.chains import RetrievalQA
6
+ from langchain.embeddings import HuggingFaceEmbeddings
7
+ from langchain.vectorstores import FAISS
8
+ from llm.prompts import qa_template
9
+ from llm.llm import setup_llm
10
+
11
# Import config vars
# Runs at import time: loads config.yml from the working directory into an
# attribute-accessible Box (e.g. cfg.VECTOR_COUNT, cfg.DB_FAISS_PATH).
# NOTE(review): raises FileNotFoundError if config.yml is absent — importing
# this module therefore requires the config file to exist.
with open('config.yml', 'r', encoding='utf8') as ymlfile:
    cfg = box.Box(yaml.safe_load(ymlfile))
14
+
15
+
16
def set_qa_prompt():
    """Create the PromptTemplate for the retrieval QA chain.

    Binds the module-level qa_template to its two placeholders,
    'context' and 'question'.

    Returns:
        A PromptTemplate ready to be passed to the QA chain.
    """
    return PromptTemplate(
        template=qa_template,
        input_variables=['context', 'question'],
    )
19
+
20
+
21
def build_retrieval_qa_chain(llm, prompt, vectordb):
    """Assemble a 'stuff'-type RetrievalQA chain.

    Args:
        llm: Language model used to generate the answer.
        prompt: PromptTemplate with 'context' and 'question' placeholders.
        vectordb: Vector store queried for the top cfg.VECTOR_COUNT chunks.

    Returns:
        A RetrievalQA chain; it also returns the source documents with each
        answer when cfg.RETURN_SOURCE_DOCUMENTS is true.
    """
    retriever = vectordb.as_retriever(search_kwargs={'k': cfg.VECTOR_COUNT})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=cfg.RETURN_SOURCE_DOCUMENTS,
        chain_type_kwargs={'prompt': prompt},
    )
29
+
30
+
31
def setup_qa_chain():
    """Wire embeddings, vector store, LLM and prompt into one QA chain.

    Loads the FAISS index from cfg.DB_FAISS_PATH using CPU-based
    HuggingFace embeddings (model name from cfg.EMBEDDINGS), then combines
    it with the local LLM and the extraction prompt.

    Returns:
        A ready-to-query RetrievalQA chain.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=cfg.EMBEDDINGS,
        model_kwargs={'device': 'cpu'},
    )
    vectordb = FAISS.load_local(cfg.DB_FAISS_PATH, embeddings)
    return build_retrieval_qa_chain(setup_llm(), set_qa_prompt(), vectordb)