File size: 4,836 Bytes
0c6d13f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
import pandas as pd
import transformers
from pynvml import *
import torch
from langchain import hub
from model_ret import zephyr_model,llama_model,mistral_model,phi_model,flant5_model
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from create_retriever import ensemble_retriever
# Short UI model name -> Hugging Face repository id (used for online loading).
hf_model_map = dict(
    Zephyr="HuggingFaceH4/zephyr-7b-beta",
    Llama="NousResearch/Meta-Llama-3-8B",
    Mistral="unsloth/mistral-7b-instruct-v0.3",
    Phi="microsoft/Phi-3-mini-4k-instruct",
    Flant5="google/flan-t5-base",
)

# Model chain class: wraps an LLM, a document retriever, and the composed
# RAG chain behind one object.
class model_chain:
    # Class-level default; overwritten per instance in __init__.
    model_name = ""

    def __init__(self,
                 model_name_local,
                 model_name_online="Llama",
                 use_local=True,
                 embedding_name="BAAI/bge-base-en-v1.5",
                 splitter_type_dropdown="character",
                 chunk_size_slider=512,
                 chunk_overlap_slider=30,
                 separator_textbox="\n",
                 max_tokens_slider=2048) -> None:
        """Build the LLM, the ensemble retriever, and the RAG chain.

        model_name_local is expected to be "<quantization>_<ModelName>"
        (e.g. "4bit_Zephyr") when use_local is True; otherwise
        model_name_online selects an entry from hf_model_map.
        """
        if use_local:
            # Split once, at the first underscore only, so a model name that
            # itself contains "_" is not truncated.
            quantization, self.model_name = model_name_local.split("_", 1)
            model_name_temp = model_name_local
        else:
            # BUG FIX: `quantization` was previously undefined on this branch,
            # raising NameError for every non-Flant5 online model.
            quantization = None
            self.model_name = model_name_online
            model_name_temp = hf_model_map[model_name_online]

        # Dispatch table replaces the if/elif chain; Flant5 is special-cased
        # because its factory also returns a tokenizer and raw model.
        builders = {
            "Zephyr": zephyr_model,
            "Llama": llama_model,
            "Mistral": mistral_model,
            "Phi": phi_model,
        }
        if self.model_name == "Flant5":
            self.tokenizer, self.model, self.llm = flant5_model(model_name_temp, use_local=use_local)
        elif self.model_name in builders:
            self.llm = builders[self.model_name](model_name_temp, quantization, use_local=use_local)
        else:
            # Previously an unknown name silently left self.llm undefined and
            # crashed later with an opaque AttributeError.
            raise ValueError(f"Unknown model name: {self.model_name!r}")

        # Creating the retriever over the document store.
        self.retriever = ensemble_retriever(embedding_name,
                                            splitter_type=splitter_type_dropdown,
                                            chunk_size=chunk_size_slider,
                                            chunk_overlap=chunk_overlap_slider,
                                            separator=separator_textbox,
                                            max_tokens=max_tokens_slider)

        # RAG chain: retrieve -> format docs -> prompt -> LLM -> plain string.
        prompt = hub.pull("rlm/rag-prompt")
        self.rag_chain = (
            {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
            | prompt
            | self.llm
            | StrOutputParser()
        )

    def format_docs(self, docs):
        """Join retrieved documents' page_content into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def rag_chain_ret(self):
        """Return the composed RAG chain."""
        return self.rag_chain

    def ans_ret(self, inp, rag_chain):
        """Answer question `inp` via `rag_chain` (manual path for Flan-T5).

        Returns the answer text; for Flan-T5 the text is wrapped to 100 cols.
        """
        if self.model_name == "Flant5":
            # BUG FIX: the user's question `inp` was ignored here in favour of
            # a hard-coded debug question ("What is KUET?").
            data = self.retriever.invoke(inp)
            # Use only the top two retrieved chunks as context.
            context = "".join(x.page_content + "\n" for x in data[:2])
            prompt_text = f"Please answer to this question using this context:\n{context}\n{inp}"
            inputs = self.tokenizer(prompt_text, return_tensors="pt")
            outputs = self.model.generate(**inputs)
            answer = self.tokenizer.decode(outputs[0])
            from textwrap import fill
            return fill(answer, width=100)

        ans = rag_chain.invoke(inp)
        # Keep only the text after the "Answer:" marker when present.
        # Previously split(...)[1] raised IndexError if the model omitted it.
        _, marker, tail = ans.partition("Answer:")
        return tail if marker else ans

# NOTE(review): dead commented-out code below — it references an undefined
# `model_name` and misspells "Zephyr" as "Zepyhr"; fix or delete before
# re-enabling.
# def model_push(hf):
#     from transformers import AutoTokenizer, AutoModelForCausalLM
#     if model_name=="Mistral":
#         path="models/full_KUET_LLM_mistral"
#     elif model_name=="Zepyhr":
#         path="models/full_KUET_LLM_zepyhr"
#     elif model_name=="Llama2":
#         path="models/full_KUET_LLM_llama" 
#     tokenizer = AutoTokenizer.from_pretrained(path)
#     model = AutoModelForCausalLM.from_pretrained(path,
#                                                     device_map='auto',
#                                                     torch_dtype=torch.float16,
#                                                     use_auth_token=True,
#                                                     load_in_8bit=True,
#                                                     #  load_in_4bit=True
#                                                     )
#     model.push_to_hub(repo_id=f"My_model",token=hf)
#     tokenizer.push_to_hub(repo_id=f"My_model",token=hf)