dnzblgn committed
Commit b58d41a · verified · 1 Parent(s): b8dc088

Update app.py

Files changed (1)
app.py +23 -14
app.py CHANGED
@@ -9,38 +9,46 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain_community.llms import HuggingFaceEndpoint
 
-# Load Whisper model
+# Load Whisper model (you can use "base", "small", "medium", or "large")
 model = whisper.load_model("tiny")
 
-# Global states
+# Model config for Hugging Face Inference API
+hub = {
+    "HF_MODEL_ID": "mistralai/Mistral-7B-Instruct-v0.2",  # Must be Inference API compatible
+    "HF_TASK": "text-generation",
+    "HF_API_TOKEN": os.environ["HUGGING_FACE_READ_TOKEN"]
+}
+
+# Global state
 vector_db = None
 qa_chain = None
 
-# Function to transcribe and initialize RAG
+# Function to transcribe and initialize RAG pipeline
 def transcribe_and_setup(audio_file_path):
     global vector_db, qa_chain
 
     if audio_file_path is None:
         return "No audio uploaded.", None, None, ""
 
+    # Transcribe with Whisper
     result = model.transcribe(audio_file_path)
-    transcript = result['text']
+    transcript = result["text"]
 
-    # Build vector DB
+    # Split and embed transcript
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
     splits = text_splitter.create_documents([transcript])
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
     vector_db = FAISS.from_documents(splits, embeddings)
 
-    # Create QA chain
+    # Create retriever + LLM QA chain
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
     retriever = vector_db.as_retriever()
     llm = HuggingFaceEndpoint(
-        repo_id="mistralai/Mistral-7B-v0.1",
-        huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
+        repo_id=hub["HF_MODEL_ID"],
+        task=hub["HF_TASK"],
+        huggingfacehub_api_token=hub["HF_API_TOKEN"],
         temperature=0.5,
-        max_new_tokens=512,
-        task="text-generation"
+        max_new_tokens=512
     )
     qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)
 
@@ -50,14 +58,14 @@ def transcribe_and_setup(audio_file_path):
 def answer_question(question):
     global qa_chain
     if qa_chain is None:
-        return "Please upload an audio file and process it first."
+        return "Please upload and process an audio file first."
     response = qa_chain.invoke({"question": question, "chat_history": []})
-    return response['answer']
+    return response["answer"]
 
 # Gradio UI
 with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none !important;}") as demo:
     gr.Markdown("## 🎙️ **Audio Intelligence Assistant**")
-    gr.Markdown("Upload an audio file, get the transcript, and ask questions about the content!")
+    gr.Markdown("Upload an audio file, get the transcript, and ask questions about its content!")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -66,7 +74,7 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none !importan
             status_output = gr.Textbox(label="🛠️ Status", interactive=False)
             transcript_output = gr.Textbox(label="📝 Transcript", lines=10, interactive=False)
         with gr.Column(scale=1):
-            question_input = gr.Textbox(label="❓ Ask a question about the audio", placeholder="What is the audio about?")
+            question_input = gr.Textbox(label="❓ Ask a question about the audio", placeholder="e.g., What was discussed?")
             ask_button = gr.Button("💬 Ask")
             answer_output = gr.Textbox(label="🤖 Answer", lines=5)
 
@@ -83,3 +91,4 @@ with gr.Blocks(theme=gr.themes.Monochrome(), css="footer {display:none !importan
     )
 
 demo.launch()
+
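A quick way to sanity-check the new endpoint settings outside the app: the sketch below calls HuggingFaceEndpoint directly with the same repo_id, task, and token lookup this commit introduces. The prompt string is illustrative only, and it assumes langchain_community is installed and the HUGGING_FACE_READ_TOKEN environment variable (the Space secret) is set.

import os
from langchain_community.llms import HuggingFaceEndpoint

# Same configuration as the hub dict added in this commit.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    task="text-generation",
    huggingfacehub_api_token=os.environ["HUGGING_FACE_READ_TOKEN"],
    temperature=0.5,
    max_new_tokens=512,
)

# One-off generation to confirm the Inference API accepts this model/task pair.
print(llm.invoke("Summarize what a retrieval-augmented QA chain does."))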