Felipe Silva committed
Commit eb6c217 · 1 Parent(s): df8b30e

design pattern adjustment

Files changed (1):
  1. rag_utils.py +55 -28
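
The commit replaces eager module-level initialization (the embedding model, LLM, and tokenizer were all loaded at import time) with lazily initialized module-level singletons behind getter functions. A minimal sketch of that pattern, with hypothetical names (_instance, get_instance, load_expensive_resource):

    _instance = None  # module-level singleton, created on first use

    def get_instance():
        # Load the shared resource on the first call; reuse it afterwards.
        global _instance
        if _instance is None:
            _instance = load_expensive_resource()  # hypothetical expensive loader
        return _instance

Deferring the load keeps the import cheap and lets configuration (here, config.local_emb_path and config.local_model_path) be set after import but before first use.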
rag_utils.py CHANGED
@@ -17,32 +17,46 @@ device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else

import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
- cache_dir = "/home/user/.cache/huggingface" #"./model/qwen-awq" #"/home/felipe/.cache/huggingface/transformers" #"/home/user/.cache/huggingface"
+ # cache_dir = "/home/user/.cache/huggingface" #"./model/qwen-awq" #"/home/felipe/.cache/huggingface/transformers" #"/home/user/.cache/huggingface"

- embedding_model = HuggingFaceEmbeddings(model_name=config.local_emb_path)
+ _embedding_instance = None
+ _model_instance = None
+ _tokenizer = None
+ 
+ def get_embedding_model():
+     global _embedding_instance
+     if _embedding_instance is None:
+         if config.local_emb_path is None:
+             raise ValueError("⚠️ config.local_emb_path has not been initialized yet!")
+         _embedding_instance = HuggingFaceEmbeddings(model_name=config.local_emb_path)
+     return _embedding_instance

# model_name = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8" #"Qwen/Qwen2.5-7B-Instruct-AWQ" #"Qwen/Qwen2.5-7B-Instruct"
- model = AutoModelForCausalLM.from_pretrained(
-     config.local_model_path,
-     torch_dtype="auto",
-     device_map="auto",
-     trust_remote_code=True,
-     # cache_dir=cache_dir
- )
- model.to(device)
- tokenizer = AutoTokenizer.from_pretrained(config.local_model_path, trust_remote_code=True)#, cache_dir=cache_dir)
- 
- pipe = pipeline(
-     "text-generation",
-     model=model,
-     tokenizer=tokenizer,
-     max_new_tokens=512,
-     temperature=0.1,
-     do_sample=False
- )
- 
- # Adapt for LangChain
- llm = HuggingFacePipeline(pipeline=pipe)
+ 
+ def get_model():
+     global _model_instance
+     if _model_instance is None:
+         if config.local_model_path is None:
+             raise ValueError("⚠️ config.local_model_path has not been initialized yet!")
+         _model_instance = AutoModelForCausalLM.from_pretrained(
+             config.local_model_path,
+             torch_dtype="auto",
+             device_map="auto",
+             trust_remote_code=True
+         )
+ 
+     return _model_instance
+ 
+ # _model_instance.to(device)
+ 
+ def get_tokenizer():
+     global _tokenizer
+     if _tokenizer is None:
+         if config.local_model_path is None:
+             raise ValueError("⚠️ config.local_model_path has not been initialized yet!")
+         _tokenizer = AutoTokenizer.from_pretrained(config.local_model_path, trust_remote_code=True)
+ 
+     return _tokenizer

def create_split_doc(raw_text):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
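
Note that getters written this way are not thread-safe: two threads can both pass the None check and load the model twice. An alternative with the same lazy, cached behavior, sketched here on the assumption that the loader takes no arguments, is functools.lru_cache (concurrent first calls can still trigger duplicate loads):

    from functools import lru_cache

    @lru_cache(maxsize=None)  # no arguments, so this caches the single result: a lazy singleton
    def get_model():
        if config.local_model_path is None:
            raise ValueError("config.local_model_path has not been initialized yet!")
        return AutoModelForCausalLM.from_pretrained(
            config.local_model_path,
            torch_dtype="auto",
            device_map="auto",
            trust_remote_code=True,
        )

Dropping the old model.to(device) call is consistent with device_map="auto", which already dispatches the weights to the available devices.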
@@ -51,6 +65,7 @@ def create_split_doc(raw_text):
    return docs

def store_docs(docs):
+     embedding_model = get_embedding_model()
    vectorstore = FAISS.from_documents(docs, embedding_model)
    return vectorstore

@@ -73,14 +88,26 @@ Pergunta:
    return prompt_template

def create_rag_chain(vectorstore):
+     pipe = pipeline(
+         "text-generation",
+         model=get_model(),
+         tokenizer=get_tokenizer(),
+         max_new_tokens=512,
+         temperature=0.1,
+         do_sample=False
+     )
+ 
+     # Adapt for LangChain
+     llm = HuggingFacePipeline(pipeline=pipe)
+ 
    rag_chain = RetrievalQA.from_chain_type(
-         llm=llm,
-         retriever=vectorstore.as_retriever(),
-         chain_type="stuff",
-         chain_type_kwargs={"prompt": create_template()}
-     )
+         llm=llm,
+         retriever=vectorstore.as_retriever(),
+         chain_type="stuff",
+         chain_type_kwargs={"prompt": create_template()}
+     )
+ 
    return rag_chain

- 
if __name__ == '__main__':
    pass
 
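
For reference, a minimal end-to-end usage sketch of the refactored module, assuming config is a plain module whose local_emb_path and local_model_path are assigned before any getter runs; the paths and input file are hypothetical:

    import config
    import rag_utils

    config.local_emb_path = "./models/embeddings"     # hypothetical local paths
    config.local_model_path = "./models/qwen2.5-7b"

    with open("document.txt", encoding="utf-8") as f:  # hypothetical source document
        raw_text = f.read()

    docs = rag_utils.create_split_doc(raw_text)           # chunk into 500-char documents
    vectorstore = rag_utils.store_docs(docs)              # first call loads the embedding model
    rag_chain = rag_utils.create_rag_chain(vectorstore)   # loads the LLM and tokenizer, builds the chain
    result = rag_chain.invoke({"query": "What does the document say?"})
    print(result["result"])  # RetrievalQA returns its answer under the "result" key

Nothing heavy happens until store_docs and create_rag_chain run, which is the point of the refactor.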