Viper51 committed on
Commit
ebe7149
·
verified ·
1 Parent(s): 750dd7d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +87 -99
src/streamlit_app.py CHANGED
@@ -1,34 +1,39 @@
 
1
  """
2
- ChatYT Streamlit App (API-Only Version)
3
 
4
  This Streamlit app enables you to:
5
  * Summarise YouTube videos
6
  * Ask questions about the topics discussed in the video
7
 
8
- It uses Google's Gemini APIs for all AI tasks.
9
  """
10
 
11
  import streamlit as st
12
  import yt_dlp
13
  import os
14
- import textwrap
15
  from langchain_core.documents import Document
16
- from langchain.text_splitter import RecursiveCharacterTextSplitter
 
17
  from langchain_chroma import Chroma
18
- from langchain_google_genai import GoogleGenerativeAIEmbeddings
 
 
 
 
19
  import google.generativeai as genai
20
- from langchain.prompts import ChatPromptTemplate
21
- import time # To simulate progress
22
 
23
  # --- App Configuration ---
24
  st.set_page_config(
25
- page_title="ChatYT",
26
  page_icon="📺",
27
  layout="wide",
28
  )
29
 
30
  st.title("📺 ChatYT: Chat with any YouTube Video")
31
- st.caption("Summarize and ask questions about any YouTube video using Google's Gemini APIs.")
32
 
33
  # --- API Key Handling ---
34
  GEMINI_API_KEY = st.secrets.get("GEMINI_API_KEY")
@@ -42,6 +47,7 @@ if not GEMINI_API_KEY:
42
  st.error("Please provide your Gemini API Key in the sidebar to continue.")
43
  st.stop()
44
 
 
45
  try:
46
  genai.configure(api_key=GEMINI_API_KEY)
47
  except Exception as e:
@@ -76,6 +82,7 @@ def compress_audio(input_file, output_file="compressed.mp3"):
76
  def speech_to_text(audio_file):
77
  """
78
  Transcribes audio using the Gemini API.
 
79
  """
80
  try:
81
  model = genai.GenerativeModel("gemini-2.5-flash")
@@ -102,110 +109,65 @@ def speech_to_text(audio_file):
102
  @st.cache_data(show_spinner="Summarizing text...")
103
  def summarize_text_api(text):
104
  """
105
- Summarizes the text using the Gemini API.
106
  """
107
- model = genai.GenerativeModel("gemini-2.5-flash")
108
- prompt = f"""Please provide a concise, high-level summary of the following text:
 
 
 
 
 
109
  ---
110
  {text}
111
  ---
112
  Provide only the summary."""
 
 
 
 
113
 
114
  try:
115
- response = model.generate_content(prompt)
116
- if response.candidates and response.candidates[0].content.parts:
117
- return response.candidates[0].content.parts[0].text
118
- else:
119
- return "Error: Could not summarize text."
120
  except Exception as e:
121
  st.error(f"An error occurred during summarization: {e}")
122
  return f"Error: {e}"
123
 
124
  @st.cache_data(show_spinner="Generating embeddings...")
125
- def generate_embeddings(text):
126
  """
127
  Splits text, generates embeddings via API, and stores in ChromaDB.
 
128
  """
129
  doc = Document(page_content=text, metadata={"source": "youtube"})
 
130
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
131
  chunks = splitter.split_documents([doc])
132
 
133
  try:
134
- embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
135
- # Using a unique persist_directory for each session
136
- # Note: In a real-world deployed Streamlit app, this directory is temporary.
137
- # For persistence, a proper vector DB server would be needed.
138
  db = Chroma.from_documents(chunks, embeddings)
139
  return db
140
  except Exception as e:
141
  st.error(f"An error occurred during embedding generation: {e}")
142
  return None
143
 
144
- # --- Q&A Functions ---
145
-
146
- def closest(query, db):
147
- """
148
- Finds the most relevant text chunks from the vector database.
149
- """
150
- if db is None:
151
- st.warning("Database not initialized.")
152
- return None
153
- try:
154
- results = db.similarity_search(query, k=3)
155
- if len(results) == 0:
156
- return None
157
- return results
158
- except Exception as e:
159
- st.error(f"Error during similarity search: {e}")
160
- return None
161
-
162
-
163
- def create_prompt(results, question):
164
- """
165
- Creates a prompt for the Q&A model based on retrieved chunks.
166
- """
167
- PROMPT = """Answer the following questions based only on the following context:
168
- {context}
169
- ---
170
- Answer the question based on the above context:
171
- {que}
172
- """
173
- if not results:
174
- return "Sorry, I couldn’t find anything relevant in the video transcript."
175
-
176
- context_text = "\n\n---\n\n".join(
177
- doc.page_content for doc in results
178
- )
179
- prompt_template = ChatPromptTemplate.from_template(PROMPT)
180
- return prompt_template.format(context=context_text, que=question)
181
-
182
- def answer_llm(question, closest_chunks):
183
- """
184
- Answers the question using the Gemini API and context.
185
- """
186
- model = genai.GenerativeModel("gemini-2.5-flash")
187
- prompt = create_prompt(closest_chunks, question)
188
-
189
- if prompt == "Sorry, I couldn’t find anything relevant in the video transcript.":
190
- return prompt
191
-
192
- try:
193
- response = model.generate_content(prompt)
194
- if response.candidates and response.candidates[0].content.parts:
195
- return response.candidates[0].content.parts[0].text
196
- else:
197
- return "No answer generated."
198
- except Exception as e:
199
- st.error(f"An error occurred during Q&A: {e}")
200
- return f"Error: {e}"
201
 
202
  # --- Streamlit UI Components ---
203
 
204
  # Initialize session state variables
205
  if "summary" not in st.session_state:
206
  st.session_state.summary = ""
207
- if "db" not in st.session_state:
208
- st.session_state.db = None
209
  if "video_title" not in st.session_state:
210
  st.session_state.video_title = ""
211
  if "chat_history" not in st.session_state:
@@ -217,6 +179,12 @@ if st.button("Process Video", key="process_video"):
217
  if url:
218
  with st.spinner("Processing video... This may take a few minutes."):
219
  try:
 
 
 
 
 
 
220
  # 1. Download
221
  audio_file, video_title = download_audio(url)
222
  st.session_state.video_title = video_title
@@ -230,20 +198,41 @@ if st.button("Process Video", key="process_video"):
230
  st.error(f"Failed to transcribe: {text}")
231
  st.stop()
232
 
233
- # 4. Summarize
234
  summary = summarize_text_api(text)
235
- if "Error:" in summary:
236
- st.error(f"Failed to summarize: {summary}")
237
- st.session_state.summary = "Could not generate summary."
238
- else:
239
- st.session_state.summary = summary
240
 
241
- # 5. Embed
242
- db = generate_embeddings(text)
243
- if db is None:
244
- st.error("Failed to create vector database.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  else:
246
- st.session_state.db = db
247
 
248
  # Clean up local files
249
  try:
@@ -252,8 +241,6 @@ if st.button("Process Video", key="process_video"):
252
  except OSError as e:
253
  st.warning(f"Could not clean up audio files: {e}")
254
 
255
- st.success("Video processed successfully!")
256
-
257
  except Exception as e:
258
  st.error(f"An error occurred during video processing: {e}")
259
  else:
@@ -274,20 +261,21 @@ if st.session_state.summary:
274
 
275
  # Chat input
276
  if prompt := st.chat_input("Ask a question about the video..."):
277
- if st.session_state.db:
278
  # Add user message to history
279
  st.session_state.chat_history.append(("user", prompt))
280
  with st.chat_message("user"):
281
  st.markdown(prompt)
282
 
283
- # Generate and display bot response
284
  with st.chat_message("assistant"):
285
  with st.spinner("Thinking..."):
286
- chunks = closest(prompt, st.session_state.db)
287
- answer = answer_llm(prompt, chunks)
288
  st.markdown(answer)
289
 
290
  # Add bot message to history
291
  st.session_state.chat_history.append(("assistant", answer))
292
  else:
293
- st.error("The vector database is not loaded. Please process a video first.")
 
 
1
+ # -*- coding: utf-8 -*-
2
  """
3
+ ChatYT Streamlit App (LCEL Chain Version)
4
 
5
  This Streamlit app enables you to:
6
  * Summarise YouTube videos
7
  * Ask questions about the topics discussed in the video
8
 
9
+ It uses LangChain Expression Language (LCEL) with Google's Gemini APIs.
10
  """
11
 
12
  import streamlit as st
13
  import yt_dlp
14
  import os
15
+ # Corrected import: Document is now in langchain_core.documents
16
  from langchain_core.documents import Document
17
+ # Corrected import: RecursiveCharacterTextSplitter is in its own package
18
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
19
  from langchain_chroma import Chroma
20
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
21
+ # Corrected import: ChatPromptTemplate is now in langchain_core.prompts
22
+ from langchain_core.prompts import ChatPromptTemplate
23
+ from langchain_core.output_parsers import StrOutputParser
24
+ from langchain_core.runnables import RunnablePassthrough
25
  import google.generativeai as genai
26
+ import time
 
27
 
28
  # --- App Configuration ---
29
  st.set_page_config(
30
+ page_title="ChatYT (LangChain)",
31
  page_icon="📺",
32
  layout="wide",
33
  )
34
 
35
  st.title("📺 ChatYT: Chat with any YouTube Video")
36
+ st.caption("Summarize and ask questions about any YouTube video using LangChain and Google Gemini.")
37
 
38
  # --- API Key Handling ---
39
  GEMINI_API_KEY = st.secrets.get("GEMINI_API_KEY")
 
47
  st.error("Please provide your Gemini API Key in the sidebar to continue.")
48
  st.stop()
49
 
50
+ # Configure the genai library (still needed for file upload)
51
  try:
52
  genai.configure(api_key=GEMINI_API_KEY)
53
  except Exception as e:
 
82
  def speech_to_text(audio_file):
83
  """
84
  Transcribes audio using the Gemini API.
85
+ (This function uses the base genai library for file upload)
86
  """
87
  try:
88
  model = genai.GenerativeModel("gemini-2.5-flash")
 
@st.cache_data(show_spinner="Summarizing text...")
def _summarize_text_cached(text):
    """
    Run the LCEL summarization chain and return the summary string.

    Exceptions are deliberately NOT caught here: st.cache_data stores
    whatever the function returns, so returning an error string would
    memoize a transient API failure for this transcript. Letting the
    exception propagate keeps failed calls out of the cache.
    """
    # 1. Define the LLM
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        temperature=0.3,
        google_api_key=GEMINI_API_KEY,
    )

    # 2. Define the prompt (the transcript is injected as a variable, so
    #    braces inside the transcript are not interpreted as placeholders)
    prompt_template = """Please provide a concise, high-level summary of the following text:
---
{text}
---
Provide only the summary."""
    summarize_prompt = ChatPromptTemplate.from_template(prompt_template)

    # 3. Define and invoke the chain: prompt -> model -> plain string
    summarize_chain = summarize_prompt | llm | StrOutputParser()
    return summarize_chain.invoke({"text": text})


def summarize_text_api(text):
    """
    Summarizes the text using a LangChain chain.

    Args:
        text: The transcript to summarize.

    Returns:
        The summary string on success. On failure an error is shown in the
        Streamlit UI and a string of the form "Error: <details>" is returned
        (same contract as before), but the failure is not cached, so a later
        call with the same text retries the API.
    """
    try:
        return _summarize_text_cached(text)
    except Exception as e:
        st.error(f"An error occurred during summarization: {e}")
        return f"Error: {e}"
137
 
@st.cache_resource(show_spinner="Generating embeddings...")
def _build_vector_db(text):
    """
    Split the transcript, embed the chunks and build a Chroma store.

    Cached with st.cache_resource rather than st.cache_data: cache_data
    serializes return values, and a vector store holding a live client is
    not a serializable data object. Exceptions propagate so that a failed
    build is never cached.
    """
    import uuid  # local import: only needed to name the collection

    doc = Document(page_content=text, metadata={"source": "youtube"})
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = splitter.split_documents([doc])

    embeddings = GoogleGenerativeAIEmbeddings(
        model="models/embedding-001",
        google_api_key=GEMINI_API_KEY,
    )
    # Use a fresh collection per build. Without an explicit name every call
    # writes to the same default collection of the in-process client, so
    # chunks from a previously processed video could be retrieved for the
    # current one. NOTE(review): confirm against the installed chromadb version.
    return Chroma.from_documents(
        chunks,
        embeddings,
        collection_name=f"chatyt-{uuid.uuid4().hex}",
    )


def generate_embeddings_db(text):
    """
    Splits text, generates embeddings via API, and stores in ChromaDB.

    Args:
        text: The full transcript to index.

    Returns:
        The Chroma database object, or None if splitting, embedding or
        store creation failed (an error is shown in the Streamlit UI).
        Failures are not cached, so calling again with the same text retries.
    """
    try:
        return _build_vector_db(text)
    except Exception as e:
        st.error(f"An error occurred during embedding generation: {e}")
        return None
157
 
def format_docs(docs, *, separator="\n\n---\n\n", empty_message="No relevant context found."):
    """
    Format retrieved documents into a single context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute
            (e.g. LangChain ``Document``). ``None`` is treated as empty.
        separator: Text placed between consecutive documents. Defaults to
            the divider the prompt template expects.
        empty_message: Text returned when there are no documents.

    Returns:
        The documents' ``page_content`` joined by ``separator``, or
        ``empty_message`` when nothing was retrieved.
    """
    # Materialize first so the emptiness check also works for generators
    # and other one-shot iterables, not just lists.
    contents = [doc.page_content for doc in (docs or ())]
    if not contents:
        return empty_message
    return separator.join(contents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  # --- Streamlit UI Components ---
165
 
166
  # Initialize session state variables
167
  if "summary" not in st.session_state:
168
  st.session_state.summary = ""
169
+ if "rag_chain" not in st.session_state:
170
+ st.session_state.rag_chain = None
171
  if "video_title" not in st.session_state:
172
  st.session_state.video_title = ""
173
  if "chat_history" not in st.session_state:
 
179
  if url:
180
  with st.spinner("Processing video... This may take a few minutes."):
181
  try:
182
+ # Reset state
183
+ st.session_state.summary = ""
184
+ st.session_state.rag_chain = None
185
+ st.session_state.video_title = ""
186
+ st.session_state.chat_history = []
187
+
188
  # 1. Download
189
  audio_file, video_title = download_audio(url)
190
  st.session_state.video_title = video_title
 
198
  st.error(f"Failed to transcribe: {text}")
199
  st.stop()
200
 
201
+ # 4. Summarize (using the new chain function)
202
  summary = summarize_text_api(text)
203
+ st.session_state.summary = summary
 
 
 
 
204
 
205
+ # 5. Embed and create DB
206
+ db = generate_embeddings_db(text)
207
+
208
+ if db:
209
+ # 6. Create RAG Chain and store it in session state
210
+ llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash",
211
+ temperature=0.3,
212
+ google_api_key=GEMINI_API_KEY)
213
+
214
+ retriever = db.as_retriever(search_kwargs={"k": 3})
215
+
216
+ PROMPT_TEMPLATE = """Answer the following questions based only on the following context:
217
+ {context}
218
+ ---
219
+ Answer the question based on the above context:
220
+ {question}
221
+ """
222
+ prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
223
+
224
+ # This is the RAG chain
225
+ rag_chain = (
226
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
227
+ | prompt
228
+ | llm
229
+ | StrOutputParser()
230
+ )
231
+
232
+ st.session_state.rag_chain = rag_chain
233
+ st.success("Video processed and Q&A chat is ready!")
234
  else:
235
+ st.error("Failed to create vector database.")
236
 
237
  # Clean up local files
238
  try:
 
241
  except OSError as e:
242
  st.warning(f"Could not clean up audio files: {e}")
243
 
 
 
244
  except Exception as e:
245
  st.error(f"An error occurred during video processing: {e}")
246
  else:
 
261
 
262
  # Chat input
263
  if prompt := st.chat_input("Ask a question about the video..."):
264
+ if st.session_state.rag_chain:
265
  # Add user message to history
266
  st.session_state.chat_history.append(("user", prompt))
267
  with st.chat_message("user"):
268
  st.markdown(prompt)
269
 
270
+ # Generate and display bot response by invoking the chain
271
  with st.chat_message("assistant"):
272
  with st.spinner("Thinking..."):
273
+ # Here we just invoke the chain with the prompt!
274
+ answer = st.session_state.rag_chain.invoke(prompt)
275
  st.markdown(answer)
276
 
277
  # Add bot message to history
278
  st.session_state.chat_history.append(("assistant", answer))
279
  else:
280
+ st.error("The Q&A chain is not ready. Please process a video first.")
281
+