Rehman1603 commited on
Commit
f8f842e
·
verified ·
1 Parent(s): 9c996d7

Upload llm_model.py

Browse files
Files changed (1) hide show
  1. llm_model.py +133 -0
llm_model.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.vectorstores import FAISS
2
+ #from langchain.llms import GooglePalm, CTransformers
3
+ from langchain.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
4
+ from langchain.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings
5
+ from langchain.prompts import PromptTemplate
6
+ from langchain.chains import RetrievalQA
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from huggingface_hub import InferenceClient
9
+ from langdetect import detect # Language detection
10
+ import os
11
+ from dotenv import load_dotenv
12
+
13
+ vector_index_path = "assets/vectordb/faiss_index"
14
+
15
+ class LlmModel:
16
+
17
+ def __init__(self):
18
+ # load dot env variables
19
+ self.load_env_variables()
20
+ # load llm model
21
+ self.hf_embeddings = self.load_huggingface_embeddings()
22
+
23
+ def load_env_variables(self):
24
+ load_dotenv() # take environment variables from .env
25
+ def detect_language(self, text):
26
+ try:
27
+ return detect(text)
28
+ except:
29
+ return "en" # Default to English if detection fails
30
+
31
+ def generate_response(self, question, context):
32
+ language = self.detect_language(question)
33
+ model_id = "mistralai/Mistral-7B-Instruct-v0.1"
34
+ inputs = {
35
+ "inputs": {
36
+ "question": question,
37
+ "context": context,
38
+ }
39
+ }
40
+
41
+ def custom_prompt(self, question, history, context):
42
+ #RAG prompt template
43
+ prompt = "<s>"
44
+ for user_prompt, bot_response in history: # provide chat history
45
+ prompt += f"[INST] {user_prompt} [/INST]"
46
+ prompt += f" {bot_response}</s>"
47
+
48
+ message_prompt = f"""
49
+ You are a question answer agent and you must strictly follow below prompt template.
50
+ Given the following context and a question, generate an answer based on this context only.
51
+ Keep answers brief and well-structured. Do not give one word answers.
52
+ If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
53
+
54
+ CONTEXT: {context}
55
+
56
+ QUESTION: {question}
57
+ """
58
+ prompt += f"[INST] {message_prompt} [/INST]"
59
+
60
+ return prompt
61
+
62
+ def format_sources(self, sources):
63
+ # format the document sources
64
+ source_results = []
65
+ for source in sources:
66
+ source_results.append(str(source.page_content) +
67
+ "\n Document: " + str(source.metadata['source']) +
68
+ " Page: " + str(source.metadata['page']))
69
+ return source_results
70
+
71
+ def mixtral_chat_inference(self, prompt, history, temperature, top_p, repetition_penalty, retriever):
72
+
73
+ context = retriever.get_relevant_documents(prompt)
74
+ sources = self.format_sources(context)
75
+ # use hugging face infrence api
76
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1",
77
+ token=os.environ["HF_TOKEN"]
78
+ )
79
+ temperature = float(temperature)
80
+ if temperature < 1e-2:
81
+ temperature = 1e-2
82
+
83
+ generate_kwargs = dict(
84
+ temperature = temperature,
85
+ max_new_tokens = 512,
86
+ top_p = top_p,
87
+ repetition_penalty = repetition_penalty,
88
+ do_sample = True
89
+ )
90
+
91
+ formatted_prompt = self.custom_prompt(prompt, history, context)
92
+
93
+ return client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False), sources
94
+
95
+
96
+
97
+ def load_huggingface_embeddings(self):
98
+ # Initialize instructor embeddings using the Hugging Face model
99
+ #return HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-large")
100
+ return HuggingFaceEmbeddings(model_name = "sentence-transformers/all-MiniLM-L6-v2",
101
+ model_kwargs={'device': 'cpu'})
102
+
103
+
104
+
105
+ def create_vector_db(self, filename):
106
+
107
+ if filename.endswith(".pdf"):
108
+ loader = PyPDFLoader(file_path=filename)
109
+ elif filename.endswith(".doc") or filename.endswith(".docx"):
110
+ loader = Docx2txtLoader(filename)
111
+ elif filename.endswith("txt") or filename.endswith("TXT"):
112
+ loader = TextLoader(filename)
113
+
114
+ # Split documents
115
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
116
+ splits = text_splitter.split_documents(loader.load())
117
+
118
+ # Check if splits list is empty
119
+ if not splits:
120
+ raise ValueError('No content to index. The document may be empty or not properly formatted.')
121
+
122
+ # Create a FAISS instance for vector database from 'data'
123
+ vectordb = FAISS.from_documents(documents = splits,
124
+ embedding = self.hf_embeddings)
125
+
126
+ # Save vector database locally
127
+ #vectordb.save_local(vector_index_path)
128
+
129
+ # set vectordb content
130
+ # Load the vector database from the local folder
131
+ #vectordb = FAISS.load_local(vector_index_path, self.hf_embeddings)
132
+ # Create a retriever for querying the vector database
133
+ return vectordb.as_retriever(search_type="similarity")