File size: 2,447 Bytes
cc65c1f
 
 
de0d049
cc65c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237763b
cc65c1f
 
 
 
 
de0d049
cc65c1f
de0d049
 
 
 
 
 
cc65c1f
de0d049
cc65c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de0d049
cc65c1f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from transformers import pipeline , AutoTokenizer
from pymilvus import MilvusClient
import os 
# from redis import Redis
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
from pymilvus import Collection
from groq import Groq

from transformers.pipelines import Pipeline
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from logging import Logger , getLogger , basicConfig , INFO , FileHandler , Formatter

def load_logger() -> Logger :
    """Return the module logger, configured to write to assets/logs/log.log.

    Idempotent: calling this more than once returns the same logger without
    attaching duplicate file handlers (the original re-attached a new
    FileHandler on every call, duplicating every log line).

    Returns:
        Logger: the module-level logger with a file handler at INFO level.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('assets/logs', exist_ok = True)

    logger : Logger = getLogger(__name__)

    # Configure the root logger's console output (no-op if already configured).
    basicConfig(
        level = INFO ,
        format = '%(asctime)s - %(levelname)s - %(message)s'
    )

    # Only attach the file handler once, even if load_logger() is called again.
    if not any(isinstance(handler , FileHandler) for handler in logger.handlers) :
        file_handler = FileHandler('assets/logs/log.log')
        file_handler.setLevel(INFO)
        file_handler.setFormatter(Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)

    return logger

def load_tokenizer(model_name = 'meta-llama/Meta-Llama-3-8B') -> PreTrainedTokenizerFast :
    """Instantiate a Hugging Face fast tokenizer for the given model id.

    Args:
        model_name: Hub model id whose tokenizer should be downloaded/loaded.

    Returns:
        PreTrainedTokenizerFast: tokenizer matching the requested model.
    """
    return AutoTokenizer.from_pretrained(model_name)

def load_milvus_client() -> MilvusClient:
    """Create a Milvus client backed by the configured database file.

    The database path comes from the MILVUS_DB_NAME environment variable,
    falling back to a local demo database file.

    Returns:
        MilvusClient: client connected to the resolved database.
    """
    database_path = os.getenv('MILVUS_DB_NAME' , 'assets/database/vectordb/demo.db')
    return MilvusClient(database_path)

# def load_redis_client(db_name : int) -> Redis : 

#     redis_client = Redis(
#         host = os.getenv('REDIS_HOST' , 'localhost') ,  
#         port = int(os.getenv('REDIS_PORT' , 6379)) , 
#         db = db_name  , 
#         decode_responses = True
#     )
    
#     return redis_client

def load_embedding_model(model_name = '') -> SentenceTransformer :
    """Load a sentence-embedding model.

    Args:
        model_name: explicit model id; when empty/falsy, the
            EMBEDDING_MODEL_NAME environment variable is consulted,
            defaulting to 'all-MiniLM-L6-v2'.

    Returns:
        SentenceTransformer: the loaded embedding model.
    """
    resolved_name = model_name or os.getenv('EMBEDDING_MODEL_NAME' , 'all-MiniLM-L6-v2')
    return SentenceTransformer(resolved_name)

def load_gemini_client() :
    """Return a configured Gemini generative model, or '' when no key is set.

    Reads GEMINI_API_KEY from the environment. The original passed a
    hardcoded placeholder string to genai.configure(), so the real key was
    never used — fixed to pass the key that was read.

    Returns:
        A genai.GenerativeModel for 'gemini-1.5-flash', or the empty string
        when GEMINI_API_KEY is unset (preserved sentinel for callers).
    """
    gemini_api_key = os.getenv('GEMINI_API_KEY' , '')

    if not gemini_api_key : model = ''
    else :
        # BUGFIX: configure with the key from the environment, not a placeholder.
        genai.configure(api_key = gemini_api_key) # ! Can deploy a Llama 3.2 Model and use that instead, which can increase speed and avoid rate limits and increase safety as well
        model = genai.GenerativeModel('gemini-1.5-flash')

    return model

def load_groq_client() -> Groq :
    """Build a Groq API client using the GROQ_API_KEY environment variable.

    Returns:
        Groq: client authenticated with the key from the environment
        (None is passed through when the variable is unset).
    """
    return Groq(api_key = os.getenv('GROQ_API_KEY'))