dnzblgn committed on
Commit 4b98922 · verified · 1 Parent(s): 7ec7c4e

Update app.py

Files changed (1)
  1. app.py +90 -0
app.py CHANGED
@@ -0,0 +1,90 @@
+ import gradio as gr
+ import whisper
+ import os
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain_community.llms import HuggingFaceEndpoint
+
+ # Load Whisper model
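+ # ("base" is a small multilingual checkpoint; larger ones trade speed for accuracy)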
+ model = whisper.load_model("base")
+
+ # Global states
+ vector_db = None
+ qa_chain = None
+
+ # Function to transcribe and initialize RAG
+ def transcribe_and_setup(audio_file):
+     global vector_db, qa_chain
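+     # (stored globally so answer_question can reuse the latest index and chain)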
+
+     if audio_file is None:
+         # One value per output component: status, transcript, answer
+         return "No audio uploaded.", "", ""
+
+     # gr.Audio(type="filepath") hands the callback a path on disk,
+     # so Whisper can read the upload directly (no temp copy needed)
+     result = model.transcribe(audio_file)
+     transcript = result["text"]
+
+     # Build vector DB
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
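+     # (1024-char chunks with 64 chars of overlap keep context across boundaries)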
+     splits = text_splitter.create_documents([transcript])
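+     # all-MiniLM-L6-v2 is a small, fast sentence-embedding model (384-dim vectors)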
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+     vector_db = FAISS.from_documents(splits, embeddings)
+
+     # Create QA chain
+     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+     retriever = vector_db.as_retriever()
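+     # HuggingFaceEndpoint calls the hosted Inference API instead of loading the
+     # model locally, so a valid HF token must be present in the environment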
+     llm = HuggingFaceEndpoint(
+         repo_id="mistralai/Mistral-7B-Instruct-v0.2",
+         huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
+         temperature=0.5,
+         max_new_tokens=512,
+         task="text-generation"
+     )
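+     # The chain condenses each follow-up with chat history into a standalone
+     # question, retrieves matching chunks, and answers from them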
+     qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)
+
+     return "Transcription and RAG setup complete!", transcript, "You can now ask a question."
+
+ # Function to ask questions
+ def answer_question(question):
+     global qa_chain
+     if qa_chain is None:
+         return "Please upload an audio file and process it first."
+     response = qa_chain.invoke({"question": question})
+     return response["answer"]
+
+ # Gradio UI
+ with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none !important;}") as demo:
+     gr.Markdown("## 🎙️ **Audio Intelligence Assistant**")
+     gr.Markdown("Upload an audio file, get the transcript, and ask questions about the content!")
+
+     with gr.Row():
+         with gr.Column(scale=1):
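+             # type="filepath" delivers the upload as a path string; current
+             # gr.Audio accepts "filepath" or "numpy", not the old "file"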
+             audio_input = gr.Audio(type="filepath", label="🎧 Upload Audio")
+             transcribe_button = gr.Button("🚀 Transcribe and Setup RAG")
+             status_output = gr.Textbox(label="🛠️ Status", interactive=False)
+             transcript_output = gr.Textbox(label="📝 Transcript", lines=10, interactive=False)
+         with gr.Column(scale=1):
+             question_input = gr.Textbox(label="❓ Ask a question about the audio", placeholder="What is the audio about?")
+             ask_button = gr.Button("💬 Ask")
+             answer_output = gr.Textbox(label="🤖 Answer", lines=5)
+
+     transcribe_button.click(
+         fn=transcribe_and_setup,
+         inputs=audio_input,
+         outputs=[status_output, transcript_output, answer_output]
+     )
+
+     ask_button.click(
+         fn=answer_question,
+         inputs=question_input,
+         outputs=answer_output
+     )
+
+ demo.launch()