import datetime
import json
import os

import gradio as gr
import numpy as np
import torch
import yt_dlp
from gradio_client import Client
from langchain.chains import (
    LLMChain,
    MapReduceDocumentsChain,
    ReduceDocumentsChain,
    StuffDocumentsChain,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import JSONLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
| """Prepare data""" | |
| whisper_jax_api = "https://sanchit-gandhi-whisper-jax.hf.space/" | |
| whisper_jax = Client(whisper_jax_api) | |
| def transcribe_audio(audio_path, task="transcribe", return_timestamps=True) -> str: | |
| text, runtime = whisper_jax.predict( | |
| audio_path, | |
| task, | |
| return_timestamps, | |
| api_name="/predict_1", | |
| ) | |
| return text | |

def format_whisper_jax_output(
    whisper_jax_output: str, max_duration: int = 60
) -> list[dict]:
    """Whisper JAX outputs are in the format
    '[00:00.000 -> 00:00.000] text\n[00:00.000 -> 00:00.000] text'.
    Returns a list of dicts with keys 'start', 'end' and 'text'.

    The segments from the Whisper JAX output are merged to form paragraphs.
    `max_duration` controls how many seconds of transcript are merged into
    one segment: with `max_duration=60`, each segment in the final output
    covers roughly 60 seconds of audio.
    """
    final_output = []
    max_duration = datetime.timedelta(seconds=max_duration)
    segments = whisper_jax_output.split("\n")
    current_start = datetime.datetime.strptime("00:00", "%M:%S")
    current_text = ""

    for i, seg in enumerate(segments):
        text = seg.split("]")[-1].strip()
        current_text += " " + text

        # Sometimes Whisper JAX returns None for a timestamp.
        try:
            end = datetime.datetime.strptime(seg[14:19], "%M:%S")
        except ValueError:
            end = current_start + max_duration

        if i == len(segments) - 1:
            final_output.append(
                {
                    "start": current_start.strftime("%H:%M:%S"),
                    "end": end.strftime("%H:%M:%S"),
                    "text": current_text.strip(),
                }
            )
        elif end - current_start >= max_duration and current_text.endswith("."):
            # We have exceeded max_duration and reached the end of a
            # sentence, so close off the current segment. If we are
            # mid-sentence, keep merging instead.
            final_output.append(
                {
                    "start": current_start.strftime("%H:%M:%S"),
                    "end": end.strftime("%H:%M:%S"),
                    "text": current_text.strip(),
                }
            )
            # Reset the running start time and text.
            current_start = end
            current_text = ""

    return final_output
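
# A rough illustration of the reformatting above, on a hypothetical
# two-segment transcript:
#
#   raw = (
#       "[00:00.000 -> 00:04.000] Hello there.\n"
#       "[00:04.000 -> 00:09.000] Welcome to the video."
#   )
#   format_whisper_jax_output(raw, max_duration=60)
#   # -> [{'start': '00:00:00', 'end': '00:00:09',
#   #      'text': 'Hello there. Welcome to the video.'}]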

def yt_audio_to_text(url: str, max_duration: int = 60):
    """Given a YouTube URL, download the audio and transcribe it to text.
    Reformat the output from Whisper JAX and save the final result in a
    JSON file.
    """
    progress = gr.Progress()
    progress(0.1)

    # Download the best available audio stream to a local file.
    with yt_dlp.YoutubeDL(
        {"extract_audio": True, "format": "bestaudio", "outtmpl": "audio.mp3"}
    ) as video:
        info_dict = video.extract_info(url, download=False)
        # The video title is used later when greeting the user.
        global video_title
        video_title = info_dict["title"]
        # `yt_dlp` expects a list of URLs here.
        video.download([url])
    progress(0.4)

    audio_file = "audio.mp3"
    result = transcribe_audio(audio_file, return_timestamps=True)
    progress(0.7)

    result = format_whisper_jax_output(result, max_duration=max_duration)
    progress(0.9)

    with open("audio.json", "w") as f:
        json.dump(result, f)
    os.remove(audio_file)
| """Load data""" | |
| def metadata_func(record: dict, metadata: dict) -> dict: | |
| """This function is used to tell the Langchain loader the keys that | |
| contain metadata and extract them. | |
| """ | |
| metadata["start"] = record.get("start") | |
| metadata["end"] = record.get("end") | |
| metadata["source"] = metadata["start"] + " -> " + metadata["end"] | |
| return metadata | |
| def load_data(): | |
| loader = JSONLoader( | |
| file_path="audio.json", | |
| jq_schema=".[]", | |
| content_key="text", | |
| metadata_func=metadata_func, | |
| ) | |
| data = loader.load() | |
| os.remove("audio.json") | |
| return data | |
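
# Each record in audio.json becomes one LangChain Document. For the
# hypothetical transcript above, a loaded document would look roughly like:
#
#   Document(
#       page_content="Hello there. Welcome to the video.",
#       metadata={"start": "00:00:00", "end": "00:00:09",
#                 "source": "00:00:00 -> 00:00:09", ...},
#   )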
| """Create embeddings and vector store""" | |
| embedding_model_name = "sentence-transformers/all-mpnet-base-v2" | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| embedding_model_kwargs = {"device": device} | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name=embedding_model_name, model_kwargs=embedding_model_kwargs | |
| ) | |
| def create_vectordb(data, n_retrieved_docs: int, collection_name="YouTube"): | |
| """Returns a retriever which is used to fetch relevant documents from | |
| the vector database. | |
| `n_retrieved_docs` is the number of retrieved documents. | |
| """ | |
| vectordb = Chroma.from_documents( | |
| documents=data, embedding=embeddings, collection_name=collection_name | |
| ) | |
| n_docs = len(vectordb.get()["ids"]) | |
| retriever = vectordb.as_retriever( | |
| search_type="mmr", search_kwargs={"k": n_retrieved_docs, "fetch_k": n_docs} | |
| ) | |
| return retriever | |
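
# A minimal usage sketch (assuming `data` holds the loaded transcript
# documents):
#
#   retriever = create_vectordb(data, n_retrieved_docs=5)
#   docs = retriever.get_relevant_documents("What is the video about?")
#
# With `search_type="mmr"`, Chroma first fetches `fetch_k` candidates by
# similarity, then picks `k` of them that are relevant yet diverse.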
| """Load LLM""" | |
| repo_id = "mistralai/Mistral-7B-Instruct-v0.3" | |
| llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"max_new_tokens": 1000}) | |
| llm.client.api_url = 'https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3' | |
| """Summarisation""" | |
| # Map | |
| map_template = """Summarise the following text: | |
| {docs} | |
| Answer:""" | |
| map_prompt = PromptTemplate.from_template(map_template) | |
| map_chain = LLMChain(llm=llm, prompt=map_prompt) | |
| # Reduce | |
| reduce_template = """The following is a set of summaries: | |
| {docs} | |
| Take these and distill it into a final, consolidated summary of the main themes \ | |
| in 150 words or less. | |
| Answer:""" | |
| reduce_prompt = PromptTemplate.from_template(reduce_template) | |
| reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt) | |
| # Takes a list of documents, combines them into a single string, and passes this to llm | |
| combine_documents_chain = StuffDocumentsChain( | |
| llm_chain=reduce_chain, document_variable_name="docs" | |
| ) | |
| # Combines and iteravely reduces the mapped documents | |
| reduce_documents_chain = ReduceDocumentsChain( | |
| # This is final chain that is called. | |
| combine_documents_chain=combine_documents_chain, | |
| # If documents exceed context for `StuffDocumentsChain` | |
| collapse_documents_chain=combine_documents_chain, | |
| # The maximum number of tokens to group documents into. | |
| token_max=4000, | |
| ) | |
| # Combining documents by mapping a chain over them, then combining results | |
| map_reduce_chain = MapReduceDocumentsChain( | |
| # Map chain | |
| llm_chain=map_chain, | |
| # Reduce chain | |
| reduce_documents_chain=reduce_documents_chain, | |
| # The variable name in the llm_chain to put the documents in | |
| document_variable_name="docs", | |
| # Return the results of the map steps in the output | |
| return_intermediate_steps=False, | |
| ) | |
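
# The flow through map_reduce_chain, roughly:
#   1. Map: each transcript document is summarised independently by map_chain.
#   2. Collapse (if needed): when the combined summaries exceed token_max,
#      they are re-summarised in groups until they fit.
#   3. Reduce: the summaries are stuffed into reduce_prompt and condensed
#      into a single final summary.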

def get_summary(documents) -> str:
    summary = map_reduce_chain.invoke(documents, return_only_outputs=True)
    return summary["output_text"].strip()
| """Contextualising the question""" | |
| contextualise_q_prompt = PromptTemplate.from_template( | |
| """Given a chat history and the latest user question \ | |
| which might reference the chat history, formulate a \ | |
| standalone question that can be understood without \ | |
| the chat history. Do NOT answer the question, just \ | |
| reformulate it if needed and otherwise return it as is. | |
| Chat history: {chat_history} | |
| Question: {question} | |
| Answer: | |
| """ | |
| ) | |
| contextualise_q_chain = contextualise_q_prompt | llm | |
| """Standalone question chain""" | |
| standalone_prompt = PromptTemplate.from_template( | |
| """Given a chat history and the latest user question, \ | |
| identify whether the question is a standalone question \ | |
| or the question references the chat history. Answer 'yes' \ | |
| if the question is a standalone question, and 'no' if the \ | |
| question references the chat history. Do not answer \ | |
| anything other than 'yes' or 'no'. | |
| Chat history: | |
| {chat_history} | |
| Question: | |
| {question} | |
| Answer: | |
| """ | |
| ) | |
| def format_output(answer: str) -> str: | |
| """All lower case and remove all whitespace to ensure | |
| that the answer given by the LLM is either 'yes' or 'no'. | |
| """ | |
| return "".join(answer.lower().split()) | |
| standalone_chain = standalone_prompt | llm | format_output | |
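
# For example, format_output("  Yes\n") returns "yes", so the routing logic
# below can compare against the bare strings 'yes' and 'no'.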
| """Q&A chain""" | |
| qa_prompt = PromptTemplate.from_template( | |
| """You are an assistant for question-answering tasks. \ | |
| ONLY use the following context to answer the question. \ | |
| Do NOT answer with information that is not contained in \ | |
| the context. If you don't know the answer, just say:\ | |
| "Sorry, I cannot find the answer to that question in the video." | |
| Context: | |
| {context} | |
| Question: | |
| {question} | |
| Answer: | |
| """ | |
| ) | |

class YouTubeChatbot:
    instance_count = 0

    def __init__(
        self,
        n_sources: int = 3,
        n_retrieved_docs: int = 5,
        timestamp_interval: datetime.timedelta = datetime.timedelta(minutes=2),
        memory: int = 5,
    ):
        # Give each chatbot instance its own vector store collection name.
        YouTubeChatbot.instance_count += 1
        self.chatbot_id = YouTubeChatbot.instance_count
        self.n_sources = n_sources
        self.n_retrieved_docs = n_retrieved_docs
        self.timestamp_interval = timestamp_interval
        self.chat_history = ConversationBufferWindowMemory(k=memory)
        self.retriever = None
        self.qa_chain = None

    def format_docs(self, docs: list) -> str:
        """Combine documents into a single string which will be included
        in the prompt given to the LLM.
        """
        # Record the start timestamps so answers can cite their sources.
        self.sources = [doc.metadata["start"] for doc in docs]
        return "\n\n".join(doc.page_content for doc in docs)

    def standalone_question(self, input_: dict) -> str:
        """If the question is not a standalone question,
        run contextualise_q_chain to reformulate it.
        """
        if input_["standalone"] == "no":
            # Returning a runnable here makes LCEL invoke it with the
            # same input, producing the contextualised question.
            return contextualise_q_chain
        else:
            return input_["question"]

    def format_answer(self, answer: str) -> str:
        """Add timestamps to answers."""
        if "cannot find the answer" in answer:
            return answer.strip()
        else:
            timestamps = self.filter_timestamps()
            answer_with_sources = (
                answer.strip() + " You can find more information "
                "at these timestamps: {}.".format(", ".join(timestamps))
            )
            return answer_with_sources

    def filter_timestamps(self) -> list[str]:
        """Return a list of timestamps with length less than or
        equal to `n_sources`. The timestamps are at least
        `timestamp_interval` apart. This prevents returning
        a list of timestamps that are too close together.
        """
        filtered_timestamps = np.array(
            [datetime.datetime.strptime(self.sources[0], "%H:%M:%S")]
        )
        i = 1
        while len(filtered_timestamps) < self.n_sources:
            try:
                new_timestamp = datetime.datetime.strptime(self.sources[i], "%H:%M:%S")
            except IndexError:
                # No more retrieved sources to consider.
                break
            # Keep the candidate only if it is far enough away from every
            # timestamp kept so far.
            absolute_time_difference = abs(new_timestamp - filtered_timestamps)
            if all(absolute_time_difference >= self.timestamp_interval):
                filtered_timestamps = np.append(filtered_timestamps, new_timestamp)
            i += 1
        filtered_timestamps = [
            timestamp.strftime("%H:%M:%S") for timestamp in filtered_timestamps
        ]
        filtered_timestamps.sort()
        return filtered_timestamps
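
    # For example, with sources ['00:10:00', '00:10:30', '00:15:00'],
    # n_sources=3 and a two-minute interval, filter_timestamps keeps
    # '00:10:00', skips '00:10:30' (too close to the first), keeps
    # '00:15:00', and returns ['00:10:00', '00:15:00'].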

    def process_video(self, url: str, data=None, retriever=None):
        """Given a YouTube URL, transcribe the YouTube audio to text,
        then set up the vector database.
        """
        yt_audio_to_text(url)
        data = load_data()
        if retriever is not None:
            # If we already have documents in the vector store, delete them.
            ids = retriever.vectorstore.get()["ids"]
            retriever.vectorstore.delete(ids)
        retriever = create_vectordb(
            data, self.n_retrieved_docs, collection_name=f"Chatbot{self.chatbot_id}"
        )
        return url, data, retriever

    def setup_qa_chain(self, retriever, qa_chain=None):
        qa_chain = (
            RunnablePassthrough.assign(standalone=standalone_chain)
            | {
                "question": self.standalone_question,
                "context": self.standalone_question | retriever | self.format_docs,
            }
            | qa_prompt
            | llm
        )
        return retriever, qa_chain
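
    # Data flow through qa_chain, roughly:
    #   {'question', 'chat_history'}
    #     -> 'standalone' is added ('yes'/'no' from standalone_chain)
    #     -> the question is passed through or contextualised, and the same
    #        (possibly reformulated) question drives retrieval for 'context'
    #     -> qa_prompt is filled in and sent to the LLM.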

    def setup_chatbot(self, url: str):
        _, _, self.retriever = self.process_video(url=url, retriever=self.retriever)
        _, self.qa_chain = self.setup_qa_chain(retriever=self.retriever)

    def get_answer(self, question: str) -> str:
        try:
            ai_msg = self.qa_chain.invoke(
                {
                    "question": question,
                    # Pass the conversation buffer rather than the memory
                    # object itself, so the prompt sees the actual chat
                    # transcript.
                    "chat_history": self.chat_history.load_memory_variables({})[
                        "history"
                    ],
                }
            )
        except AttributeError:
            raise AttributeError(
                "You haven't set up the chatbot yet. "
                "Set up the chatbot by calling the "
                "instance method `setup_chatbot`."
            )
        self.chat_history.save_context({"question": question}, {"answer": ai_msg})
        answer = self.format_answer(ai_msg)
        return answer
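
# A minimal usage sketch of the chatbot outside the web app:
#
#   bot = YouTubeChatbot()
#   bot.setup_chatbot("https://www.youtube.com/watch?v=...")
#   print(bot.get_answer("What is the video about?"))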
| """Web app""" | |
| class YouTubeChatbotApp(YouTubeChatbot): | |
| def __init__( | |
| self, | |
| n_sources: int, | |
| n_retrieved_docs: int, | |
| timestamp_interval: datetime.timedelta, | |
| memory: int, | |
| default_youtube_url: str, | |
| ): | |
| super().__init__(n_sources, n_retrieved_docs, timestamp_interval, memory) | |
| self.default_youtube_url = default_youtube_url | |
| self.memory = memory | |
| self.chat_history = None | |
| self.data = None | |
| self.retriever = None | |
| self.qa_chain = None | |
| # Gradio components | |
| self.url_input = None | |
| self.url_button = None | |
| self.app_chat_history = None | |
| self.chatbot = None | |
| self.user_input = None | |
| self.clear_button = None | |

    def greet(self, data, app_chat_history) -> dict:
        """Summarise the video and greet the user."""
        summary_message = f'Here is a summary of the video "{video_title}":'
        app_chat_history.append((None, summary_message))
        summary = get_summary(data)
        # The transcript data is no longer needed once it has been summarised.
        self.data = gr.State(None)
        app_chat_history.append((None, summary))
        greeting_message = (
            "You can ask me anything about the video. I will do my best to answer!"
        )
        app_chat_history.append((None, greeting_message))
        return {
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def question(self, user_question: str, app_chat_history) -> dict:
        """Display the question asked by the user in the chat window,
        and delete it from the input textbox.
        """
        app_chat_history.append((user_question, None))
        return {
            self.user_input: "",
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def respond(self, qa_chain, chat_history, app_chat_history) -> dict:
        """Respond to the user's latest question."""
        question = app_chat_history[-1][0]
        try:
            ai_msg = qa_chain.invoke(
                {
                    "question": question,
                    # Pass the conversation buffer rather than the memory
                    # object itself.
                    "chat_history": chat_history.load_memory_variables({})["history"],
                }
            )
        except AttributeError:
            raise gr.Error(
                "You need to process the video first by pressing the `Go` button."
            )
        chat_history.save_context({"question": question}, {"answer": ai_msg})
        answer = self.format_answer(ai_msg)
        app_chat_history.append((None, answer))
        return {
            self.qa_chain: qa_chain,
            self.chat_history: chat_history,
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def clear_chat_history(self, chat_history, app_chat_history):
        chat_history.clear()
        app_chat_history = []
        return {
            self.chat_history: chat_history,
            self.app_chat_history: app_chat_history,
            self.chatbot: app_chat_history,
        }

    def launch(self, **kwargs):
        with gr.Blocks() as demo:
            # Session state: one copy per user session.
            self.chat_history = gr.State(ConversationBufferWindowMemory(k=self.memory))
            self.app_chat_history = gr.State([])
            self.data = gr.State()
            self.retriever = gr.State()
            self.qa_chain = gr.State()

            # App structure
            with gr.Row():
                self.url_input = gr.Textbox(
                    value=self.default_youtube_url, label="YouTube URL", scale=5
                )
                self.url_button = gr.Button(value="Go", scale=1)
            self.chatbot = gr.Chatbot()
            self.user_input = gr.Textbox(label="Ask a question:")
            self.clear_button = gr.Button(value="Clear")

            # App actions
            # When a new URL is given, clear the past chat history and process
            # the new video. Set up the Q&A chain with the new video's data.
            # Provide a summary of the new video.
            self.url_button.click(
                self.clear_chat_history,
                inputs=[self.chat_history, self.app_chat_history],
                outputs=[self.chat_history, self.app_chat_history, self.chatbot],
                trigger_mode="once",
            ).then(
                self.process_video,
                inputs=[self.url_input, self.data, self.retriever],
                outputs=[self.url_input, self.data, self.retriever],
            ).then(
                self.setup_qa_chain,
                inputs=[self.retriever, self.qa_chain],
                outputs=[self.retriever, self.qa_chain],
            ).then(
                self.greet,
                inputs=[self.data, self.app_chat_history],
                outputs=[self.app_chat_history, self.chatbot],
            )

            # When a user asks a question, display the question in the chat
            # window and remove it from the text input area. Then respond
            # with the Q&A chain.
            self.user_input.submit(
                self.question,
                inputs=[self.user_input, self.app_chat_history],
                outputs=[self.user_input, self.app_chat_history, self.chatbot],
                queue=False,
            ).then(
                self.respond,
                inputs=[self.qa_chain, self.chat_history, self.app_chat_history],
                outputs=[
                    self.qa_chain,
                    self.chat_history,
                    self.app_chat_history,
                    self.chatbot,
                ],
            )

            # When the `Clear` button is clicked, clear the chat history from
            # the chat window.
            self.clear_button.click(
                self.clear_chat_history,
                inputs=[self.chat_history, self.app_chat_history],
                outputs=[self.chat_history, self.app_chat_history, self.chatbot],
                queue=False,
            )

        demo.launch(**kwargs)

if __name__ == "__main__":
    app = YouTubeChatbotApp(
        n_sources=3,
        n_retrieved_docs=5,
        timestamp_interval=datetime.timedelta(minutes=2),
        memory=5,
        default_youtube_url="https://www.youtube.com/watch?v=SZorAJ4I-sA",
    )
    app.launch()