anshumanpatil commited on
Commit
664007d
Β·
1 Parent(s): 5d6cc94

add env vars

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +8 -17
.gitignore CHANGED
@@ -8,3 +8,4 @@ wheels/
8
 
9
  # Virtual environments
10
  .venv
 
 
8
 
9
  # Virtual environments
10
  .venv
11
+ .env
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import pypdf
4
  import docx2txt
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from langchain_community.vectorstores import FAISS
@@ -9,7 +9,11 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.schema import Document
10
  from sentence_transformers import SentenceTransformer
11
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
 
 
12
 
 
 
13
  # ------------------------------
14
  # Title
15
  # ------------------------------
@@ -20,7 +24,7 @@ st.title("πŸ“š RAG For MSCI Indexes")
20
  # ------------------------------
21
  @st.cache_resource
22
  def load_model():
23
- model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
  model = AutoModelForCausalLM.from_pretrained(model_name)
26
  return pipeline("text-generation", model=model, tokenizer=tokenizer)
@@ -29,13 +33,10 @@ with st.spinner("πŸ”„ Loading Model..."):
29
  generator = load_model()
30
 
31
  # ------------------------------
32
- # File Upload
33
  # ------------------------------
34
  uploaded_file = "./msci"
35
 
36
- # ------------------------------
37
- # Extract Text
38
- # ------------------------------
39
  def extract_text(folder_path):
40
  loader = DirectoryLoader(
41
  path=folder_path,
@@ -44,9 +45,6 @@ def extract_text(folder_path):
44
  recursive=True
45
  )
46
  documents = loader.load()
47
- # doc_sources = [doc.metadata["source"] for doc in documents]
48
- # loader = TextLoader(file, encoding = "utf-8")
49
- # return doc_sources
50
  return "\n".join([doc.page_content for doc in documents])
51
 
52
  # ------------------------------
@@ -54,7 +52,7 @@ def extract_text(folder_path):
54
  # ------------------------------
55
  @st.cache_resource
56
  def build_faiss(_docs):
57
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
58
  return FAISS.from_documents(_docs, embeddings)
59
 
60
  docs = []
@@ -62,8 +60,6 @@ db = None
62
 
63
  query = st.text_input("πŸ’¬ Ask a question about MSCI Indexes", placeholder="MSCI World IMI Index")
64
 
65
- # placeholder = st.empty()
66
-
67
  if uploaded_file:
68
  text = extract_text(uploaded_file)
69
  if text:
@@ -71,12 +67,8 @@ if uploaded_file:
71
  docs = [Document(page_content=chunk) for chunk in splitter.split_text(text)]
72
  db = build_faiss(docs)
73
  st.success("βœ… Knowledge Base ready! From :- https://www.msci.com/indexes#featured-indexes")
74
- # st.info("You can ask any question regarding data feed to model is as below!")
75
- # with placeholder:
76
- # long_text = st.text_area(text, height=150, disabled=True)
77
 
78
  if query and db:
79
- # placeholder.empty()
80
  retriever = db.as_retriever(search_kwargs={"k": 3})
81
  retrieved_docs = retriever.get_relevant_documents(query)
82
  context = "\n".join([doc.page_content for doc in retrieved_docs])
@@ -89,7 +81,6 @@ if query and db:
89
  top_p=0.9
90
  )
91
 
92
- # Extract only what comes after "Answer:"
93
  generated = result[0]["generated_text"]
94
  answer_only = generated.split("Answer:")[-1].strip()
95
 
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import os
4
  import docx2txt
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from langchain_community.vectorstores import FAISS
 
9
  from langchain.schema import Document
10
  from sentence_transformers import SentenceTransformer
11
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
 
15
+ model_name = os.getenv("MODEL_NAME")
16
+ embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME")
17
  # ------------------------------
18
  # Title
19
  # ------------------------------
 
24
  # ------------------------------
25
  @st.cache_resource
26
  def load_model():
27
+ # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
  model = AutoModelForCausalLM.from_pretrained(model_name)
30
  return pipeline("text-generation", model=model, tokenizer=tokenizer)
 
33
  generator = load_model()
34
 
35
  # ------------------------------
36
+ # Extract Text
37
  # ------------------------------
38
  uploaded_file = "./msci"
39
 
 
 
 
40
  def extract_text(folder_path):
41
  loader = DirectoryLoader(
42
  path=folder_path,
 
45
  recursive=True
46
  )
47
  documents = loader.load()
 
 
 
48
  return "\n".join([doc.page_content for doc in documents])
49
 
50
  # ------------------------------
 
52
  # ------------------------------
53
  @st.cache_resource
54
  def build_faiss(_docs):
55
+ embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
56
  return FAISS.from_documents(_docs, embeddings)
57
 
58
  docs = []
 
60
 
61
  query = st.text_input("πŸ’¬ Ask a question about MSCI Indexes", placeholder="MSCI World IMI Index")
62
 
 
 
63
  if uploaded_file:
64
  text = extract_text(uploaded_file)
65
  if text:
 
67
  docs = [Document(page_content=chunk) for chunk in splitter.split_text(text)]
68
  db = build_faiss(docs)
69
  st.success("βœ… Knowledge Base ready! From :- https://www.msci.com/indexes#featured-indexes")
 
 
 
70
 
71
  if query and db:
 
72
  retriever = db.as_retriever(search_kwargs={"k": 3})
73
  retrieved_docs = retriever.get_relevant_documents(query)
74
  context = "\n".join([doc.page_content for doc in retrieved_docs])
 
81
  top_p=0.9
82
  )
83
 
 
84
  generated = result[0]["generated_text"]
85
  answer_only = generated.split("Answer:")[-1].strip()
86