Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import whisper | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from gtts import gTTS | |
| import os | |
# Hugging Face auth token for gated models (e.g. meta-llama checkpoints).
# Read from the environment so the secret is never committed to source
# control; the empty-string default preserves the previous behavior for
# public models.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", "")
@st.cache_resource
def load_whisper_model():
    """Load and cache the Whisper "base" speech-to-text model.

    Streamlit re-executes the entire script on every widget interaction,
    so without ``st.cache_resource`` the model would be reloaded from disk
    on each rerun.

    Returns:
        The loaded Whisper model instance.
    """
    return whisper.load_model("base")
@st.cache_resource
def load_llama_model(model_name="meta-llama/Llama-2-7b-chat-hf"):
    """Load and cache a causal-LM tokenizer/model pair from Hugging Face.

    Args:
        model_name: Hugging Face model id. Defaults to the official
            Llama-2-7B chat checkpoint from Meta.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # `use_auth_token` is deprecated in recent transformers releases in
    # favor of `token`. Pass None (not "") when no token is configured so
    # public downloads are not sent an empty credential.
    token = HF_AUTH_TOKEN or None
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, token=token, torch_dtype="auto"
    )
    return tokenizer, model
# Initialize both models once at module import; main() reads these globals.
whisper_model = load_whisper_model()
llama_tokenizer, llama_model = load_llama_model()
def main():
    """Streamlit app: transcribe an uploaded audio file with Whisper,
    answer the transcribed query with Llama-2, and speak the answer back
    with gTTS.
    """
    st.title("Audio Query App with Llama-2 and Whisper")

    # File upload
    uploaded_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "m4a"])
    if uploaded_file is None:
        return

    # Keep the upload's real extension: the old code saved every upload as
    # "input_audio.wav" and told the player it was audio/wav even for
    # mp3/m4a content.
    suffix = os.path.splitext(uploaded_file.name)[1].lower() or ".wav"
    input_audio_path = f"input_audio{suffix}"
    with open(input_audio_path, "wb") as f:
        f.write(uploaded_file.read())
    st.audio(input_audio_path, format=uploaded_file.type or "audio/wav")

    # Step 1: Transcribe audio.
    with st.spinner("Transcribing audio..."):
        transcription = whisper_model.transcribe(input_audio_path)["text"].strip()
    st.write(f"**Transcription:** {transcription}")
    if not transcription:
        st.warning("No speech detected in the uploaded audio.")
        return

    # Step 2: Generate a response using Llama-2.
    with st.spinner("Generating response..."):
        # Llama-2 *chat* checkpoints expect the [INST] ... [/INST] prompt
        # format; a raw prompt tends to yield a continuation, not an answer.
        prompt = f"[INST] {transcription} [/INST]"
        inputs = llama_tokenizer(prompt, return_tensors="pt")
        # max_new_tokens (not max_length) so a long transcription cannot
        # consume the entire generation budget.
        outputs = llama_model.generate(**inputs, max_new_tokens=150)
        # outputs[0] includes the prompt tokens; decode only the newly
        # generated suffix so the response does not echo the question.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        response_text = llama_tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()
    st.write(f"**Response:** {response_text}")
    if not response_text:
        st.warning("The model returned an empty response.")
        return

    # Step 3: Convert text response to audio. gTTS raises on empty text,
    # which the guard above prevents.
    with st.spinner("Converting response to audio..."):
        response_audio_path = "response_audio.mp3"
        tts = gTTS(text=response_text, lang="en")
        tts.save(response_audio_path)
    st.audio(response_audio_path, format="audio/mp3")


if __name__ == "__main__":
    main()