SurajJha21 commited on
Commit
d5d16b3
·
verified ·
1 Parent(s): 1feb939

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,11 +1,12 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModel
3
  from langchain_community.document_loaders import WebBaseLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain.chains.combine_documents import create_stuff_documents_chain
6
  from langchain_core.prompts import ChatPromptTemplate
7
  from langchain.chains import create_retrieval_chain
8
  from langchain_community.vectorstores import FAISS
 
9
  import numpy as np
10
  import torch
11
  import time
@@ -14,16 +15,18 @@ import time
14
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
15
  model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
16
 
17
- def embed_text(texts):
18
- inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
19
- with torch.no_grad():
20
- embeddings = model(**inputs).last_hidden_state.mean(dim=1)
21
- return embeddings.numpy()
22
 
23
- def embedding_function(texts):
24
- # This function converts texts to embeddings using the Hugging Face model
25
- embeddings = embed_text(texts)
26
- return embeddings
 
 
 
 
27
 
28
  if "vector" not in st.session_state:
29
  st.session_state.loader = WebBaseLoader("https://docs.nvidia.com/cuda/")
@@ -32,10 +35,10 @@ if "vector" not in st.session_state:
32
  st.session_state.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
33
  documents = st.session_state.text_splitter.split_documents(st.session_state.docs[:50])
34
 
35
- # Create FAISS index using the custom embedding function
36
  st.session_state.vectors = FAISS.from_texts(
37
  [doc.page_content for doc in documents],
38
- embedding_function
39
  )
40
 
41
  st.title("ChatGroq Demo")
 
1
  import streamlit as st
2
+ from langchain_groq import ChatGroq
3
  from langchain_community.document_loaders import WebBaseLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
  from langchain.chains.combine_documents import create_stuff_documents_chain
6
  from langchain_core.prompts import ChatPromptTemplate
7
  from langchain.chains import create_retrieval_chain
8
  from langchain_community.vectorstores import FAISS
9
+ from langchain.embeddings import HuggingFaceEmbeddings
10
  import numpy as np
11
  import torch
12
  import time
 
15
  tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
16
  model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
17
 
18
+ class CustomHuggingFaceEmbeddings(HuggingFaceEmbeddings):
19
+ def __init__(self):
20
+ super().__init__(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
 
21
 
22
+ def embed_documents(self, texts):
23
+ inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
24
+ with torch.no_grad():
25
+ embeddings = model(**inputs).last_hidden_state.mean(dim=1)
26
+ return embeddings.numpy()
27
+
28
+ # Instantiate embeddings class
29
+ embeddings = CustomHuggingFaceEmbeddings()
30
 
31
  if "vector" not in st.session_state:
32
  st.session_state.loader = WebBaseLoader("https://docs.nvidia.com/cuda/")
 
35
  st.session_state.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
36
  documents = st.session_state.text_splitter.split_documents(st.session_state.docs[:50])
37
 
38
+ # Create FAISS index using the custom embeddings class
39
  st.session_state.vectors = FAISS.from_texts(
40
  [doc.page_content for doc in documents],
41
+ embeddings
42
  )
43
 
44
  st.title("ChatGroq Demo")