Spaces:
Runtime error
Runtime error
| # trainer_manager.py | |
| from longtrainer.trainer import LongTrainer | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from config import CONNECTION_STRING, CHATGROQ_API_KEY, CUSTOM_PROMPT | |
def get_embeddings():
    """Build the HuggingFace embedding model used by this module.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for the "BAAI/bge-small-en"
        model, constructed with ``device`` set to "cpu" in the model
        options and ``normalize_embeddings`` set to True in the encode
        options.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
def get_llm():
    """Create the ChatGroq chat model configured from ``CHATGROQ_API_KEY``.

    Returns:
        A ``ChatGroq`` instance for the "llama-3.3-70b-versatile" model with
        ``temperature=0`` and ``max_tokens=1024``.

    Raises:
        ValueError: If ``CHATGROQ_API_KEY`` is empty or otherwise falsy.
    """
    if CHATGROQ_API_KEY:
        return ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=1024,
            api_key=CHATGROQ_API_KEY,
        )
    # Fail before constructing the client when no key is configured.
    raise ValueError("CHATGROQ_API_KEY is not set.")
# Model objects are built once, at import time; a failure in either factory
# (e.g. the ValueError from get_llm) therefore surfaces when this module is
# first imported.
embedding_model = get_embeddings()
llm = get_llm()

# Single module-level LongTrainer instance, returned by get_trainer().
trainer_instance = LongTrainer(
    mongo_endpoint=CONNECTION_STRING,
    llm=llm,
    embedding_model=embedding_model,
    encrypt_chats=True,
)
def get_trainer():
    """Return the module-level ``trainer_instance``.

    Every call returns the same object; no new LongTrainer is created here.
    """
    return trainer_instance