import gradio as gr
import whisper
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
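# Note: recent LangChain releases move HuggingFaceEndpoint and HuggingFaceEmbeddings
# into the langchain_huggingface package; the langchain_community imports above
# still work but may emit deprecation warnings.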

# Load the Whisper speech-to-text model ("tiny" is faster; "small" or "medium" improve accuracy over "base")
model = whisper.load_model("base")

# Model config for Hugging Face Inference API
hub = {
    "HF_MODEL_ID": "mistralai/Mistral-7B-Instruct-v0.3",  
    "HF_TASK": "text-generation",
    "HF_API_TOKEN": os.environ.get("HUGGINGFACE_API_TOKEN")
}
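# The API token is read from the environment, so it must be set before launch,
# e.g. `export HUGGINGFACE_API_TOKEN=hf_...` in the shell.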

# Global state
vector_db = None
qa_chain = None
chat_memory = None

# Transcribe and set up RAG
def transcribe_and_setup(audio_file_path):
    global vector_db, qa_chain, chat_memory

    if audio_file_path is None:
        return "No audio uploaded.", []

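    # Whisper decodes most common audio formats via ffmpeg; on CPU-only machines,
    # passing fp16=False to transcribe() silences the FP16 warning.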
    result = model.transcribe(audio_file_path)
    transcript = result["text"]

    # Split and embed
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    splits = text_splitter.create_documents([transcript])
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vector_db = FAISS.from_documents(splits, embeddings)
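    # all-MiniLM-L6-v2 is a small, CPU-friendly embedding model; the FAISS index
    # is held in memory and rebuilt from scratch on every upload.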

    # QA setup
    chat_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    retriever = vector_db.as_retriever()
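    # as_retriever() defaults to similarity search over the top 4 chunks; if
    # answers miss context, this can be tuned, e.g.
    # vector_db.as_retriever(search_kwargs={"k": 8}).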
    llm = HuggingFaceEndpoint(
        repo_id=hub["HF_MODEL_ID"],
        task=hub["HF_TASK"],
        huggingfacehub_api_token=hub["HF_API_TOKEN"],
        temperature=0.5,
        max_new_tokens=512
    )
    qa_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=chat_memory)

    return "Transcription complete! Ready for questions.", []  # Empty chat reset

# Handle conversation
def answer_question(question, chat_history):
    global qa_chain
    if qa_chain is None:
        return "Please upload and process an audio file first.", chat_history

    # ConversationBufferMemory supplies chat_history to the chain itself, so
    # only the question needs to be passed explicitly.
    response = qa_chain.invoke({"question": question})

    # Append the (question, answer) pair; the empty string clears the input box
    chat_history.append([question, response["answer"]])
    return "", chat_history

# Custom CSS
custom_css = """
.chatbox .message.user, .chatbox .message.bot {
    background-color: #1e3d2f !important;
    color: #ffffff !important;
    border-radius: 10px !important;
    padding: 10px !important;
    margin: 5px !important;
}
.chatbox .message.user::before, .chatbox .message.bot::before {
    content: none !important;
}
"""

# Gradio app
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("## πŸŽ™οΈ **Conversational Audio Chatbot**")
    gr.Markdown("Upload an audio file, let the AI process it, and ask any questions!")

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="🎧 Upload Audio")
            transcribe_button = gr.Button("πŸš€ Process Audio")
            status_output = gr.Textbox(label="πŸ› οΈ Status", interactive=False)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="πŸ’¬ Chat with your audio", elem_classes=["chatbox"])
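            # Chat history is stored as [user, assistant] pairs (the classic
            # Chatbot format; newer Gradio versions also support type="messages")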
            question_input = gr.Textbox(label="Type your question", placeholder="Ask about the audio...")
            ask_button = gr.Button("πŸ’‘ Ask")

    transcribe_button.click(
        fn=transcribe_and_setup,
        inputs=audio_input,
        outputs=[status_output, chatbot]
    )

    ask_button.click(
        fn=answer_question,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )
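
    # Optional convenience (an addition, not in the original flow): pressing
    # Enter in the textbox triggers the same handler as the Ask button.
    question_input.submit(
        fn=answer_question,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )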

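# launch() serves on localhost by default; share=True creates a temporary
# public link, and server_name="0.0.0.0" exposes the app on the local network.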
demo.launch()