"""Factory helpers that build loggers, tokenizers, vector-DB clients,
embedding models and LLM clients, mostly configured from environment variables.
"""
import os
from logging import INFO, FileHandler, Formatter, Logger, basicConfig, getLogger

import google.generativeai as genai
from groq import Groq
from pymilvus import Collection, MilvusClient  # noqa: F401  (Collection unused here; kept in case other modules import it from this one — verify)
# from redis import Redis
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, pipeline  # noqa: F401  (pipeline unused here; kept — verify)
from transformers.pipelines import Pipeline  # noqa: F401  (unused here; kept — verify)
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

_LOG_DIR = 'assets/logs'
_LOG_FILE = os.path.join(_LOG_DIR, 'log.log')
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'


def load_logger() -> Logger:
    """Return this module's logger, writing INFO-and-above records to a log file.

    Creates ``assets/logs`` if needed and calls ``basicConfig`` to configure
    the root logger. ``basicConfig`` does nothing if the root logger already
    has handlers.

    Safe to call repeatedly: the ``FileHandler`` for ``assets/logs/log.log``
    is attached only once, so records are not written multiple times.

    Returns:
        The logger named after this module.
    """
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(_LOG_DIR, exist_ok=True)

    logger: Logger = getLogger(__name__)
    basicConfig(level=INFO, format=_LOG_FORMAT)
    # If the root logger was configured elsewhere, basicConfig was a no-op and
    # the effective level may be above INFO, silently dropping INFO records
    # before they reach the file handler. Lower it to INFO only when needed,
    # so a more verbose (e.g. DEBUG) configuration is left alone.
    if logger.getEffectiveLevel() > INFO:
        logger.setLevel(INFO)

    # FileHandler stores the absolute path in ``baseFilename``; compare on that
    # so repeated calls do not stack duplicate handlers.
    log_path = os.path.abspath(_LOG_FILE)
    already_attached = any(
        isinstance(handler, FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        file_handler = FileHandler(_LOG_FILE)
        file_handler.setLevel(INFO)
        file_handler.setFormatter(Formatter(_LOG_FORMAT))
        logger.addHandler(file_handler)
    return logger


def load_tokenizer(model_name: str = 'meta-llama/Meta-Llama-3-8B') -> PreTrainedTokenizerFast:
    """Load the Hugging Face tokenizer for ``model_name`` via ``AutoTokenizer``.

    Args:
        model_name: Hub model id or local path.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``. The
        return annotation assumes a fast tokenizer is available for the model;
        this is not checked here.
    """
    return AutoTokenizer.from_pretrained(model_name)


def load_milvus_client() -> MilvusClient:
    """Create a ``MilvusClient`` for the database named by ``MILVUS_DB_NAME``.

    Falls back to ``assets/database/vectordb/demo.db`` when the variable is unset.
    """
    db_name = os.getenv('MILVUS_DB_NAME', 'assets/database/vectordb/demo.db')
    return MilvusClient(db_name)


# def load_redis_client(db_name : int) -> Redis :
#     redis_client = Redis(
#         host = os.getenv('REDIS_HOST' , 'localhost') ,
#         port = int(os.getenv('REDIS_PORT' , 6379)) ,
#         db = db_name ,
#         decode_responses = True
#     )
#     return redis_client


def load_embedding_model(model_name: str = '') -> SentenceTransformer:
    """Load a ``SentenceTransformer`` embedding model.

    Args:
        model_name: Model to load. If empty, the ``EMBEDDING_MODEL_NAME``
            environment variable is used, defaulting to ``all-MiniLM-L6-v2``.
    """
    if not model_name:
        model_name = os.getenv('EMBEDDING_MODEL_NAME', 'all-MiniLM-L6-v2')
    return SentenceTransformer(model_name)


def load_gemini_client():
    """Return a ``gemini-1.5-flash`` ``GenerativeModel`` configured with ``GEMINI_API_KEY``.

    Returns:
        The model, or the empty string ``''`` when ``GEMINI_API_KEY`` is unset
        or empty. The sentinel is falsy, so callers can test the result
        with ``if model:``.
    """
    gemini_api_key = os.getenv('GEMINI_API_KEY', '')
    if not gemini_api_key:
        return ''
    # Pass the key read above. Previously '' was passed here, which discarded
    # the configured key entirely.
    genai.configure(api_key=gemini_api_key)
    # TODO: Consider deploying a self-hosted Llama 3.2 model instead. It could
    # improve latency, avoid third-party rate limits, and give more control
    # over safety.
    return genai.GenerativeModel('gemini-1.5-flash')


def load_groq_client() -> Groq:
    """Create a Groq client using the ``GROQ_API_KEY`` environment variable.

    If the variable is unset, ``None`` is passed as ``api_key`` and the
    Groq SDK's own handling of a missing key applies (not checked here).
    """
    return Groq(api_key=os.getenv('GROQ_API_KEY'))