anshumanpatil commited on
Commit
99102a9
Β·
1 Parent(s): 558d39c

add other parameters in dir

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import streamlit as st
2
  import pandas as pd
3
- # import os
4
- # import docx2txt
5
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_huggingface import HuggingFaceEmbeddings
@@ -9,8 +7,6 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
  from langchain.schema import Document
10
  from sentence_transformers import SentenceTransformer
11
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
12
- # from dotenv import load_dotenv
13
- # load_dotenv()
14
 
15
 
16
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@@ -25,12 +21,14 @@ extracted_text = None
25
  # ------------------------------
26
  # Load Model for pretraining
27
  # ------------------------------
 
28
  def load_model():
29
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
30
  tokenizer = AutoTokenizer.from_pretrained(model_name)
31
  model = AutoModelForCausalLM.from_pretrained(model_name)
32
  return pipeline("text-generation", model=model, tokenizer=tokenizer)
33
 
 
34
  def extract_text():
35
  uploaded_data_path = "./msci"
36
  loader = DirectoryLoader(
@@ -63,22 +61,27 @@ with st.spinner("πŸ”„ Loading Knowldge Base..."):
63
  st.title("πŸ“š RAG For MSCI Indexes")
64
  st.markdown("This app uses a local LLM model to answer questions about MSCI Indexes using RAG (Retrieval Augmented Generation).")
65
 
66
- query = st.text_input("πŸ’¬ Ask a question about MSCI Indexes", placeholder="MSCI World IMI Index")
67
-
68
- if query and db and extracted_text and len(docs) > 0:
69
- retriever = db.as_retriever(search_kwargs={"k": 3})
70
- retrieved_docs = retriever.get_relevant_documents(query)
71
- context = "\n".join([doc.page_content for doc in retrieved_docs])
72
-
73
- with st.spinner("πŸ€” Generating answer..."):
74
- result = generator(
75
- f"Context:\n{context}\n\nQuestion: {query}\nAnswer:",
76
- max_new_tokens=150,
77
- temperature=0.5,
78
- top_p=0.9
79
- )
80
-
81
- generated = result[0]["generated_text"]
82
- answer_only = generated.split("Answer:")[-1].strip()
83
-
84
- st.write("πŸ“ Answer:", answer_only)
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
 
 
3
  from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_huggingface import HuggingFaceEmbeddings
 
7
  from langchain.schema import Document
8
  from sentence_transformers import SentenceTransformer
9
  from langchain_community.document_loaders import DirectoryLoader, TextLoader
 
 
10
 
11
 
12
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
21
  # ------------------------------
22
  # Load Model for pretraining
23
  # ------------------------------
24
+ @st.cache_resource
25
  def load_model():
26
  # model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  model = AutoModelForCausalLM.from_pretrained(model_name)
29
  return pipeline("text-generation", model=model, tokenizer=tokenizer)
30
 
31
+ @st.cache_resource
32
  def extract_text():
33
  uploaded_data_path = "./msci"
34
  loader = DirectoryLoader(
 
61
  st.title("πŸ“š RAG For MSCI Indexes")
62
  st.markdown("This app uses a local LLM model to answer questions about MSCI Indexes using RAG (Retrieval Augmented Generation).")
63
 
64
+ with st.form(key='my_form'):
65
+ query = st.text_input("πŸ’¬ Ask a question about MSCI Indexes(Required)", placeholder="MSCI World IMI Index")
66
+ max_new_tokens_model = st.slider("Max New Tokens (Optional):", min_value=50, max_value=500, value=150, step=25)
67
+ temperature_model = st.slider("Temperature (Optional):", min_value=0.0, max_value=0.9, value=0.5, step=0.1)
68
+ submit_button = st.form_submit_button("Submit")
69
+
70
+ if submit_button:
71
+ if query and db and extracted_text and len(docs) > 0:
72
+ retriever = db.as_retriever(search_kwargs={"k": 3})
73
+ retrieved_docs = retriever.get_relevant_documents(query)
74
+ context = "\n".join([doc.page_content for doc in retrieved_docs])
75
+
76
+ with st.spinner("πŸ€” Generating answer..."):
77
+ result = generator(
78
+ f"Context:\n{context}\n\nQuestion: {query}\nAnswer:",
79
+ max_new_tokens=max_new_tokens_model,
80
+ temperature=temperature_model,
81
+ top_p=0.9
82
+ )
83
+
84
+ generated = result[0]["generated_text"]
85
+ answer_only = generated.split("Answer:")[-1].strip()
86
+
87
+ st.write("πŸ“ Answer:", answer_only)