File size: 3,703 Bytes
d0c774c
 
 
 
 
 
 
 
 
 
 
ee5fb3c
f30b648
d0c774c
 
 
 
 
 
 
eb6c217
d0c774c
eb6c217
 
 
 
 
 
 
 
 
81ce4aa
eb6c217
d0c774c
f30b648
d0c774c
c3bd22b
eb6c217
 
 
 
 
 
 
c3bd22b
b62b49f
eb6c217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0c774c
 
 
 
 
 
 
c3bd22b
d0c774c
eb6c217
d0c774c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3bd22b
d0c774c
eb6c217
 
 
 
 
 
d54767c
 
eb6c217
 
 
 
 
d0c774c
eb6c217
 
 
d54767c
eb6c217
 
 
d0c774c
 
 
16f21ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline

from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA

import spaces
import config
import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
device = f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu'

import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# cache_dir = "/home/user/.cache/huggingface" #"./model/qwen-awq" #"/home/felipe/.cache/huggingface/transformers" #"/home/user/.cache/huggingface"

_embedding_instance = None
_model_instance = None
_tokenizer = None

def get_embedding_model():
    global _embedding_instance
    if _embedding_instance is None:
        if config.local_emb_path is None:
            raise ValueError("⚠️ config.local_emb_path ainda não foi inicializado!")
        _embedding_instance = HuggingFaceEmbeddings(model_name=config.local_emb_path, model_kwargs={"device": "cpu"})
    return _embedding_instance

# model_name = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8" #"Qwen/Qwen2.5-7B-Instruct-AWQ" #"Qwen/Qwen2.5-7B-Instruct"

# @spaces.GPU
def get_model():
    global _model_instance
    if _model_instance is None:
        if config.local_model_path is None:
            raise ValueError("⚠️ config.local_model_path ainda não foi inicializado!")
        _model_instance = AutoModelForCausalLM.from_pretrained(
            config.local_model_path,
            dtype=torch.float16,
            device_map={"": "cuda"},
            trust_remote_code=True
        )
    
    return _model_instance

# _model_instance.to(device)

def get_tokenizer():
    global _tokenizer
    if _tokenizer is None:
        if config.local_model_path is None:
            raise ValueError("⚠️ config.local_model_path ainda não foi inicializado!")
        _tokenizer = AutoTokenizer.from_pretrained(config.local_model_path, trust_remote_code=True)
    
    return _tokenizer

def create_split_doc(raw_text):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = text_splitter.create_documents([raw_text])

    return docs

# @spaces.GPU
def store_docs(docs):
    embedding_model = get_embedding_model()
    vectorstore = FAISS.from_documents(docs, embedding_model)
    return vectorstore

def create_template():
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Você é um especialista em extrair informações em documentos.
Com base nas informações a seguir, forneça a melhor resposta.
Caso não tenha certeza da resposta, prefira falar que não sabe responder tal pergunta.
Responda de maneira amigável e clara.

Contexto:
{context}

Pergunta:
{question}
"""
)
    return prompt_template

# @spaces.GPU
def create_rag_chain(vectorstore):
    pipe = pipeline(
        "text-generation",
        model=get_model(),
        tokenizer=get_tokenizer(),
        max_new_tokens=512,
        temperature=0.1,
        do_sample=False,
        return_full_text=False
    )

    # Adapta para LangChain
    llm = HuggingFacePipeline(pipeline=pipe)

    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        chain_type="stuff",
        return_source_documents=True,
        chain_type_kwargs={"prompt": create_template()}
    )
    
    return rag_chain

if __name__ == '__main__':
    pass