RAG_Audio_files / app.py
dnzblgn's picture
Update app.py
49807da verified
import gradio as gr
import whisper
import os
import tempfile
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
# Load Whisper model (change to "base", "small", or "medium" for better accuracy)
model = whisper.load_model("base")
# Model config for Hugging Face Inference API
hub = {
"HF_MODEL_ID": "mistralai/Mistral-7B-Instruct-v0.3",
"HF_TASK": "text-generation",
"HF_API_TOKEN": os.environ.get("HUGGINGFACE_API_TOKEN")
}
# Global state
vector_db = None
qa_chain = None
chat_memory = None
# Transcribe and set up RAG
def transcribe_and_setup(audio_file_path):
global vector_db, qa_chain, chat_memory
if audio_file_path is None:
return "No audio uploaded.", []
result = model.transcribe(audio_file_path)
transcript = result["text"]
# Split and embed
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
splits = text_splitter.create_documents([transcript])
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_db = FAISS.from_documents(splits, embeddings)
# QA setup
chat_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
retriever = vector_db.as_retriever()
llm = HuggingFaceEndpoint(
repo_id=hub["HF_MODEL_ID"],
task=hub["HF_TASK"],
huggingfacehub_api_token=hub["HF_API_TOKEN"],
temperature=0.5,
max_new_tokens=512
)
qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=chat_memory)
return "Transcription complete! Ready for questions.", [] # Empty chat reset
# Handle conversation
def answer_question(question, chat_history):
global qa_chain, chat_memory
if qa_chain is None:
return "Please upload and process an audio file first.", chat_history
response = qa_chain.invoke({
"question": question,
"chat_history": chat_memory.load_memory_variables({})["chat_history"]
})
# Just show back-and-forth messages
chat_history.append([question, response["answer"]])
return "", chat_history
# Custom CSS
custom_css = """
.chatbox .message.user, .chatbox .message.bot {
background-color: #1e3d2f !important;
color: #ffffff !important;
border-radius: 10px !important;
padding: 10px !important;
margin: 5px !important;
}
.chatbox .message.user::before, .chatbox .message.bot::before {
content: none !important;
}
"""
# Gradio app
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
gr.Markdown("## πŸŽ™οΈ **Conversational Audio Chatbot**")
gr.Markdown("Upload an audio file, let the AI process it, and ask any questions!")
with gr.Row():
with gr.Column(scale=1):
audio_input = gr.Audio(type="filepath", label="🎧 Upload Audio")
transcribe_button = gr.Button("πŸš€ Process Audio")
status_output = gr.Textbox(label="πŸ› οΈ Status", interactive=False)
with gr.Column(scale=2):
chatbot = gr.Chatbot(label="πŸ’¬ Chat with your audio", elem_classes=["chatbox"])
question_input = gr.Textbox(label="Type your question", placeholder="Ask about the audio...")
ask_button = gr.Button("πŸ’‘ Ask")
transcribe_button.click(
fn=transcribe_and_setup,
inputs=audio_input,
outputs=[status_output, chatbot]
)
ask_button.click(
fn=answer_question,
inputs=[question_input, chatbot],
outputs=[question_input, chatbot]
)
demo.launch()