claytonsds commited on
Commit
6f39035
·
verified ·
1 Parent(s): cc8abc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -46
app.py CHANGED
@@ -7,46 +7,24 @@ from langchain_community.document_loaders import UnstructuredURLLoader
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
10
- from langchain_core.prompts import PromptTemplate
11
- from langchain_core.output_parsers import StrOutputParser
12
 
13
- # Get HuggingFace API token from environment variables
14
- token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
15
 
 
 
16
 
17
  # ------------------------
18
- # LLM Model (LLaMA 2)
19
  # ------------------------
20
- pipe = pipeline(
21
- task="text-generation",
22
- model="HuggingFaceTB/SmolLM2-360M",
23
- temperature=0.7,
24
- max_new_tokens=512,
25
- token=token,
26
- device_map="auto"
27
  )
28
 
29
- llm = HuggingFacePipeline(
30
- pipeline=pipe,
31
- model_kwargs={"temperature": 0.7}
32
- )
33
-
34
- # ------------------------
35
- # Prompt template
36
- # ------------------------
37
- prompt = PromptTemplate.from_template(
38
- """Given the following extracted parts of a long document and a question, create a final answer with references.
39
- If you don't know the answer, just say that you don't know.
40
- Question: {question}"""
41
- )
42
 
43
  # Global variable to store the QA chain
44
- simple_chain = None
45
 
46
- # ------------------------
47
- # Function to process URLs with real-time logging
48
- # ------------------------
49
- # ------------------------
50
  # Paths to save FAISS and URLs
51
  # ------------------------
52
  FAISS_FILE = "vectorstore.pkl"
@@ -56,7 +34,7 @@ URLS_FILE = "urls.pkl"
56
  # Function to process URLs with logging and FAISS management
57
  # ------------------------
58
  def process_urls_with_logs(url1, url2, url3):
59
- global simple_chain
60
 
61
  urls = [url1, url2, url3]
62
  urls = [u.strip() for u in urls if u.strip() != ""]
@@ -101,9 +79,9 @@ def process_urls_with_logs(url1, url2, url3):
101
  pickle.dump(urls, f)
102
 
103
  print("Initializing LLM chain...")
104
- retriever = vectorstore.as_retriever()
105
- from langchain_core.runnables import RunnableSequence
106
- simple_chain = RunnableSequence(prompt, llm, StrOutputParser())
107
  return "FAISS successfully created/recreated!"
108
 
109
  else:
@@ -112,23 +90,24 @@ def process_urls_with_logs(url1, url2, url3):
112
  with open(FAISS_FILE, "rb") as f:
113
  vectorstore = pickle.load(f)
114
 
115
- retriever = vectorstore.as_retriever()
116
- from langchain_core.runnables import RunnableSequence
117
- simple_chain = RunnableSequence(prompt, llm, StrOutputParser())
118
-
119
  return "Existing FAISS loaded."
120
 
121
  # ------------------------
122
  # Function to answer questions
123
  # ------------------------
124
  def ask_question(question):
125
- global simple_chain
126
 
127
- if simple_chain is None:
128
  return "Please process URLs first."
129
 
130
- result = simple_chain.invoke({"question": question})
131
- return result
 
 
 
 
132
 
133
  # ------------------------
134
  # Gradio Interface
@@ -158,6 +137,7 @@ with gr.Blocks() as app:
158
 
159
  ask_btn = gr.Button("Ask")
160
  answer_output = gr.Textbox(label="Answer", lines=8)
 
161
 
162
  # Connect buttons to their functions
163
  process_btn.click(
@@ -167,10 +147,10 @@ with gr.Blocks() as app:
167
  )
168
 
169
  ask_btn.click(
170
- ask_question,
171
- inputs=question_box,
172
- outputs=answer_output
173
- )
174
 
175
  # Launch the Gradio app
176
  app.launch()
 
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain_community.vectorstores import FAISS
9
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
10
+ from langchain_google_genai import ChatGoogleGenerativeAI
 
11
 
 
 
12
 
13
# Get the Google API key from environment variables.
# (It is passed as `api_key` to ChatGoogleGenerativeAI below — the old
# comment calling it a "HuggingFace API token" was left over from the
# previous LLaMA/HF pipeline and was misleading.)
token = os.environ.get("API_TOKEN")
15
 
16
# ------------------------
# LLM
# ------------------------
# Google Gemini chat model. `token` comes from the environment lookup above;
# temperature 0.7 keeps answers moderately creative.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.7,
    api_key=token,
)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
# Global variable holding the QA chain; populated by process_urls_with_logs()
# and read by ask_question(). None until URLs have been processed.
chain = None
27
 
 
 
 
 
28
  # Paths to save FAISS and URLs
29
  # ------------------------
30
  FAISS_FILE = "vectorstore.pkl"
 
34
  # Function to process URLs with logging and FAISS management
35
  # ------------------------
36
  def process_urls_with_logs(url1, url2, url3):
37
+ global chain
38
 
39
  urls = [url1, url2, url3]
40
  urls = [u.strip() for u in urls if u.strip() != ""]
 
79
  pickle.dump(urls, f)
80
 
81
  print("Initializing LLM chain...")
82
+
83
+ chain = RetrievalQAWithSourcesChain.from_llm( llm=llm, retriever=vectorstore.as_retriever())
84
+
85
  return "FAISS successfully created/recreated!"
86
 
87
  else:
 
90
  with open(FAISS_FILE, "rb") as f:
91
  vectorstore = pickle.load(f)
92
 
93
+ chain = RetrievalQAWithSourcesChain.from_llm( llm=llm, retriever=vectorstore.as_retriever())
 
 
 
94
  return "Existing FAISS loaded."
95
 
96
  # ------------------------
97
  # Function to answer questions
98
  # ------------------------
99
def ask_question(question):
    """Answer a question against the indexed documents.

    Parameters
    ----------
    question : str
        The user's natural-language question.

    Returns
    -------
    tuple[str, str]
        (answer, sources). Always a 2-tuple, because Gradio wires this
        function to two output widgets (answer_output, sources_output).
    """
    global chain

    if chain is None:
        # Bug fix: this branch previously returned a bare string while the
        # success path returned a 2-tuple — with two Gradio outputs both
        # branches must return (answer, sources).
        return "Please process URLs first.", ""

    result = chain.invoke({"question": question})

    # RetrievalQAWithSourcesChain yields a dict; "answer"/"sources" keys
    # match the usage elsewhere in this file.
    answer = result.get("answer", "")
    sources = result.get("sources", "")
    return answer, sources
111
 
112
  # ------------------------
113
  # Gradio Interface
 
137
 
138
  ask_btn = gr.Button("Ask")
139
  answer_output = gr.Textbox(label="Answer", lines=8)
140
+ sources_output = gr.Textbox(label="Sources", lines=4)
141
 
142
  # Connect buttons to their functions
143
  process_btn.click(
 
147
  )
148
 
149
  ask_btn.click(
150
+ ask_question,
151
+ inputs=question_box,
152
+ outputs=[answer_output, sources_output]
153
+ )
154
 
155
  # Launch the Gradio app
156
  app.launch()