dev-models's picture
Initial commit
e97c8d1
raw
history blame contribute delete
771 Bytes
import torch
from sentence_transformers import SentenceTransformer
from groq import Groq
from config import CLIP_MODEL_NAME, GROQ_API_KEY, LLM_MODEL_NAME
from langchain_groq import ChatGroq
def get_clip_model(model_name: str = CLIP_MODEL_NAME):
    """Load a SentenceTransformer model, falling back to ``clip-ViT-B-32``.

    The requested model is constructed with ``trust_remote_code=True`` and
    then moved to ``"cuda"`` when ``torch.cuda.is_available()`` is true,
    otherwise to ``"cpu"``.

    Args:
        model_name: Model identifier handed to ``SentenceTransformer``.
            Defaults to ``CLIP_MODEL_NAME`` from ``config``.

    Returns:
        The loaded model. If constructing or moving the requested model
        raises any ``Exception``, the error is printed and a freshly
        constructed ``SentenceTransformer('clip-ViT-B-32')`` is returned
        instead.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    try:
        # NOTE(review): trust_remote_code=True presumably lets the model
        # repository run its own code at load time -- confirm the configured
        # model source is trusted.
        clip_model = SentenceTransformer(model_name, trust_remote_code=True)
        clip_model.to(target_device)
    except Exception as e:
        # Deliberate best-effort: any failure above switches to the stock
        # CLIP checkpoint rather than propagating.
        print(f"Fallback CLIP model due to: {e}")
        # The fallback is not explicitly moved to target_device; it stays on
        # whatever device the library picks by default.
        return SentenceTransformer('clip-ViT-B-32')

    return clip_model
def get_llm(model_name: str = LLM_MODEL_NAME):
    """Create a ``ChatGroq`` chat model with a fixed temperature of 0.1.

    Args:
        model_name: Model identifier passed as ``model``. Defaults to
            ``LLM_MODEL_NAME`` from ``config``.

    Returns:
        A new ``ChatGroq`` instance authenticated with ``GROQ_API_KEY``.
    """
    llm_options = {
        "model": model_name,
        "api_key": GROQ_API_KEY,
        "temperature": 0.1,
    }
    return ChatGroq(**llm_options)
def get_groq_client(api_key: str = GROQ_API_KEY):
    """Create a ``Groq`` API client.

    Args:
        api_key: Key passed to the client constructor. Defaults to
            ``GROQ_API_KEY`` from ``config``.

    Returns:
        A new ``Groq`` client instance.
    """
    groq_client = Groq(api_key=api_key)
    return groq_client