Shouvik599 commited on
Commit ·
da91d6e
1
Parent(s): 0848bca
replaced gemini with nvidia model
Browse files- rag_chain.py +6 -6
rag_chain.py
CHANGED
|
@@ -21,13 +21,12 @@ Returns a dict with:
|
|
| 21 |
import os
|
| 22 |
from dotenv import load_dotenv
|
| 23 |
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
| 24 |
-
from langchain_google_genai import
|
| 25 |
from langchain_chroma import Chroma
|
| 26 |
from langchain_core.prompts import ChatPromptTemplate
|
| 27 |
from langchain_core.output_parsers import StrOutputParser
|
| 28 |
load_dotenv()
|
| 29 |
|
| 30 |
-
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
| 31 |
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
|
| 32 |
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH", "./chroma_db")
|
| 33 |
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "sacred_texts")
|
|
@@ -150,11 +149,12 @@ def build_chain():
|
|
| 150 |
embeddings = get_embeddings()
|
| 151 |
vector_store = get_vector_store(embeddings)
|
| 152 |
|
| 153 |
-
llm =
|
| 154 |
-
model="
|
| 155 |
-
|
| 156 |
temperature=0.2,
|
| 157 |
-
|
|
|
|
| 158 |
)
|
| 159 |
|
| 160 |
prompt = ChatPromptTemplate.from_messages([
|
|
|
|
| 21 |
import os
|
| 22 |
from dotenv import load_dotenv
|
| 23 |
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
|
| 24 |
+
from langchain_google_genai import ChatNVIDIA
|
| 25 |
from langchain_chroma import Chroma
|
| 26 |
from langchain_core.prompts import ChatPromptTemplate
|
| 27 |
from langchain_core.output_parsers import StrOutputParser
|
| 28 |
load_dotenv()
|
| 29 |
|
|
|
|
| 30 |
NVIDIA_API_KEY = os.getenv("NVIDIA_API_KEY")
|
| 31 |
CHROMA_DB_PATH = os.getenv("CHROMA_DB_PATH", "./chroma_db")
|
| 32 |
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "sacred_texts")
|
|
|
|
| 149 |
embeddings = get_embeddings()
|
| 150 |
vector_store = get_vector_store(embeddings)
|
| 151 |
|
| 152 |
+
llm = ChatNVIDIA(
|
| 153 |
+
model="meta/llama-3.3-70b-instruct",
|
| 154 |
+
api_key=NVIDIA_API_KEY,
|
| 155 |
temperature=0.2,
|
| 156 |
+
top_p=0.7,
|
| 157 |
+
max_output_tokens=2048,
|
| 158 |
)
|
| 159 |
|
| 160 |
prompt = ChatPromptTemplate.from_messages([
|