Spaces:
Build error
Build error
| import gradio as gr | |
| from audio_processing import process_audio | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM | |
| import spaces | |
| import torch | |
| import logging | |
| import traceback | |
| import sys | |
# Route INFO-and-above log records to stdout with a timestamped format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
def load_translation_model():
    """Load the NLLB-200 distilled 600M checkpoint used for translation.

    Returns:
        A ``(model, tokenizer)`` tuple -- note that the model comes first.
    """
    checkpoint = "facebook/nllb-200-distilled-600M"
    # The tokenizer is fetched before the model, as in the original ordering.
    nllb_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    return nllb_model, nllb_tokenizer
def alternate_translation(translation_model, translation_tokenizer, inputs):
    """Translate ``inputs`` with the given seq2seq model, targeting ``eng_Latn``.

    Args:
        translation_model: Object exposing ``generate(**encoded, ...)``.
        translation_tokenizer: Callable tokenizer that also provides
            ``convert_tokens_to_ids`` and ``batch_decode``.
        inputs: Text to translate.

    Returns:
        The first decoded sequence with special tokens stripped, or an empty
        string when ``inputs`` is empty, whitespace-only or ``None`` (the
        model is not invoked in that case).
    """
    # Nothing to translate: skip tokenisation and generation entirely.
    if not inputs or (isinstance(inputs, str) and not inputs.strip()):
        return ""

    encoded = translation_tokenizer(inputs, return_tensors='pt')
    # Forcing the first generated token to the "eng_Latn" language code is
    # what selects the output language.
    target_lang_id = translation_tokenizer.convert_tokens_to_ids("eng_Latn")
    # Generation is capped at max_length=100, so long inputs may be cut short.
    generated = translation_model.generate(
        **encoded,
        forced_bos_token_id=target_lang_id,
        max_length=100,
    )
    decoded = translation_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
def load_qa_model():
    """Build the text-generation pipeline used for question answering.

    Returns:
        The pipeline for ``meta-llama/Meta-Llama-3-8B-Instruct``, or ``None``
        when construction raises (the error is logged as a warning).
    """
    logger.info("Loading Q&A model...")
    checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
    try:
        generator = pipeline(
            "text-generation",
            model=checkpoint,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
    except Exception as exc:
        # Best effort: callers are expected to check for None.
        logger.warning(f"Failed to load Q&A model. Error: \n{str(exc)}")
        return None
    logger.info("Q&A model loaded successfully")
    return generator
def load_summarization_model():
    """Build the summarization pipeline, retrying on CPU if the first try fails.

    The first attempt uses device 0 when CUDA is available and device -1
    otherwise. Any exception triggers a second attempt pinned to device -1;
    an exception from that second attempt propagates to the caller.

    Returns:
        The ``sshleifer/distilbart-cnn-12-6`` summarization pipeline.
    """
    logger.info("Loading summarization model...")
    checkpoint = "sshleifer/distilbart-cnn-12-6"
    try:
        on_gpu = torch.cuda.is_available()
        device_index = 0 if on_gpu else -1
        primary = pipeline("summarization", model=checkpoint, device=device_index)
        device_label = 'GPU' if on_gpu else 'CPU'
        logger.info(f"Summarization model loaded successfully on {device_label}")
        return primary
    except Exception as exc:
        logger.warning(f"Failed to load summarization model on GPU. Falling back to CPU. Error: \n{str(exc)}")
        cpu_only = pipeline("summarization", model=checkpoint, device=-1)
        logger.info("Summarization model loaded successfully on CPU")
        return cpu_only
def process_with_fallback(func, *args, **kwargs):
    """Call ``func`` and retry once with ``use_gpu=False`` on GPU-looking errors.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns, from the first call or from the retry.

    Raises:
        Exception: The original error is re-raised when its message contains
            neither "CUDA" nor "GPU"; errors from the retry also propagate.
    """
    try:
        return func(*args, **kwargs)
    except Exception as exc:
        message = str(exc)
        logger.error(f"Error during processing: {message}")
        logger.error(traceback.format_exc())
        looks_like_gpu_failure = any(marker in message for marker in ("CUDA", "GPU"))
        if not looks_like_gpu_failure:
            raise
        logger.info("Falling back to CPU processing...")
        # The retry adds (or overrides) use_gpu, so func must accept it.
        kwargs['use_gpu'] = False
        return func(*args, **kwargs)
def transcribe_audio(audio_file, translate, model_size):
    """Run ``process_audio`` through ``process_with_fallback``.

    Args:
        audio_file: Forwarded to ``process_audio`` as its positional argument.
        translate: Forwarded as the ``translate`` keyword.
        model_size: Forwarded as the ``model_size`` keyword.

    Returns:
        Whatever ``process_audio`` returns.

    Raises:
        gr.Error: Wraps any exception raised during processing.
    """
    logger.info(f"Starting transcription: translate={translate}, model_size={model_size}")
    try:
        outcome = process_with_fallback(
            process_audio,
            audio_file,
            translate=translate,
            model_size=model_size,
        )
        logger.info("Transcription completed successfully")
        return outcome
    except Exception as exc:
        failure = f"Transcription failed: {str(exc)}"
        logger.error(failure)
        # gr.Error surfaces the message in the Gradio UI.
        raise gr.Error(failure)
def summarize_text(text):
    """Summarize ``text`` with the pipeline from ``load_summarization_model``.

    Args:
        text: Text passed to the summarization pipeline.

    Returns:
        The ``summary_text`` of the first result, or a fixed error message
        when anything in the process raises (the error is logged).
    """
    logger.info("Starting text summarization")
    try:
        summarization_pipeline = load_summarization_model()
        results = summarization_pipeline(text, max_length=150, min_length=50, do_sample=False)
        first_summary = results[0]['summary_text']
        logger.info("Summarization completed successfully")
        return first_summary
    except Exception as exc:
        logger.error(f"Summarization failed: {str(exc)}")
        logger.error(traceback.format_exc())
        return "Error occurred during summarization. Please try again."
def process_and_summarize(audio_file, translate, model_size, do_summarize=True):
    """Transcribe audio, optionally translate each segment, optionally summarize.

    Args:
        audio_file: Audio input forwarded to ``transcribe_audio``.
        translate: When truthy, each segment's text is also translated with
            ``alternate_translation`` and the translation is what feeds the
            full text.
        model_size: Model size, forwarded to ``transcribe_audio`` and echoed
            in the report header.
        do_summarize: When truthy, the full text is summarized.

    Returns:
        A ``(transcription, full_text, summary)`` tuple of strings:
        the formatted report, the segment texts joined by single spaces, and
        the summary ("" when ``do_summarize`` is falsy).

    Raises:
        gr.Error: Wraps any exception raised along the way.
    """
    logger.info(f"Starting process_and_summarize: translate={translate}, model_size={model_size}, do_summarize={do_summarize}")
    try:
        language_segments, final_segments = transcribe_audio(audio_file, translate, model_size)

        # Only load the translation model when it will actually be used.
        if translate:
            translation_model, translation_tokenizer = load_translation_model()

        report_parts = []
        for segment in language_segments:
            report_parts.append(f"Language: {segment['language']}\n")
            report_parts.append(f"Time: {segment['start']:.2f}s - {segment['end']:.2f}s\n\n")
        report_parts.append(
            f"Transcription with language detection and speaker diarization (using {model_size} model):\n\n"
        )

        spoken_texts = []
        for segment in final_segments:
            report_parts.append(
                f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:\n"
            )
            report_parts.append(f"Original: {segment['text']}\n")
            if translate:
                translated = alternate_translation(translation_model, translation_tokenizer, segment['text'])
                # Terminate the line so the next segment starts on its own line.
                report_parts.append(f"Translated: {translated}\n")
                spoken_texts.append(translated)
            else:
                spoken_texts.append(segment['text'])
            # Blank line between segments.
            report_parts.append("\n")

        transcription = "".join(report_parts)
        # Join with spaces so adjacent segments do not run together.
        full_text = " ".join(spoken_texts)

        summary = summarize_text(full_text) if do_summarize else ""
        logger.info("Process and summarize completed successfully")
        return transcription, full_text, summary
    except Exception as e:
        logger.error(f"Process and summarize failed: {str(e)}\n")
        logger.error(traceback.format_exc())
        raise gr.Error(f"Processing failed: {str(e)}") from e
def answer_question(context, question):
    """Answer ``question`` about ``context`` using the Q&A pipeline.

    Args:
        context: Text the answer should be based on.
        question: The user's question.

    Returns:
        The assistant's reply as a string. A fixed message is returned when
        the context or question is blank, when the model cannot be loaded,
        or when no assistant message is found; any other failure yields an
        error string that includes the exception text.
    """
    # Validate first: there is no point loading the model for blank input.
    if not (context or "").strip():
        return "No transcription available. Please process an audio file first."
    if not (question or "").strip():
        return "Please enter a question."

    logger.info("Starting Q&A process")
    try:
        qa_pipeline = load_qa_model()
        if qa_pipeline is None:
            return "Error: Q&A model could not be loaded."

        messages = [
            {"role": "system", "content": "You are a helpful assistant who can answer questions based on the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
        ]

        # NOTE(review): this alternate system prompt is defined but is not
        # sent to the model; it is kept here for reference only.
        alternate_system_message = """
        You are an AI assistant designed to analyze speech transcriptions in a safe and responsible manner.
        Your purpose is to assist people, not to monitor or detect threats.
        When responding to user queries, your primary goals are:
        1. To provide factual, accurate information to the best of your abilities.
        2. To guide users towards appropriate resources and authorities if they are facing an emergency or urgent situation.
        3. To refrain from speculating about or escalating potentially concerning situations without clear justification.
        4. To avoid making judgements or taking actions that could infringe on individual privacy or civil liberties.
        However, if the speech suggests someone may be in immediate danger or that a crime is being planned, you should:
        - Identify & report
        - Identify any cryptic information and report it.
        - Avoid probing for additional details or speculating about the nature of the potential threat.
        - Do not provide any information that could enable or encourage harmful, illegal or unethical acts.
        Your role is to be a helpful, informative assistant.
        """

        out = qa_pipeline(messages, max_new_tokens=256)
        logger.info(f"Raw model output: {out}")
        generated_text = out[0]['generated_text']

        if isinstance(generated_text, str):
            # NOTE(review): presumably some pipeline versions return plain
            # text instead of a message list -- confirm against the installed
            # transformers version. Use the text as-is rather than crashing.
            answer = generated_text
        else:
            # Find the assistant's message among the returned chat turns.
            answer = next(
                (message['content'] for message in generated_text if message['role'] == 'assistant'),
                "No assistant response found in the model's output.",
            )

        logger.info(f"Extracted answer: {answer}")
        return answer
    except Exception as e:
        logger.error(f"Q&A process failed: {str(e)}")
        logger.error(traceback.format_exc())
        return f"Error occurred during Q&A process. Please try again. Error: {str(e)}"
# Main interface
with gr.Blocks() as iface:
    gr.Markdown("# WhisperX Audio Transcription, Translation, Summarization, and Q&A (with ZeroGPU support)")

    # Inputs
    audio_input = gr.Audio(type="filepath")
    translate_checkbox = gr.Checkbox(label="Enable Translation")
    # Starts disabled; enabled by update_summarize_checkbox below.
    summarize_checkbox = gr.Checkbox(label="Enable Summarization", interactive=False)
    model_dropdown = gr.Dropdown(choices=["tiny", "base", "small", "medium", "large", "large-v2", "large-v3"], label="Whisper Model Size", value="small")
    process_button = gr.Button("Process Audio")

    # Outputs
    transcription_output = gr.Textbox(label="Transcription/Translation")
    full_text_output = gr.Textbox(label="Full Text")
    summary_output = gr.Textbox(label="Summary")

    # Q&A
    question_input = gr.Textbox(label="Ask a question about the transcription")
    answer_button = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")

    # On-demand translation of the summary
    translate_alternate = gr.Button("Alternate Translation")
    translate_alternate_output = gr.Textbox(label="Alternate Translation")

    def update_summarize_checkbox(translate):
        """Make the summarize checkbox interactive only while translation is on."""
        return gr.Checkbox(interactive=translate)

    def translate_summary(text):
        """Translate the summary textbox content.

        ``alternate_translation`` requires a model and a tokenizer in
        addition to the text, but the click event supplies only the textbox
        value, so both are loaded here before delegating.
        """
        if not text or not text.strip():
            # Nothing to translate; avoid loading the model.
            return ""
        translation_model, translation_tokenizer = load_translation_model()
        return alternate_translation(translation_model, translation_tokenizer, text)

    translate_checkbox.change(update_summarize_checkbox, inputs=[translate_checkbox], outputs=[summarize_checkbox])

    process_button.click(
        process_and_summarize,
        inputs=[audio_input, translate_checkbox, model_dropdown, summarize_checkbox],
        outputs=[transcription_output, full_text_output, summary_output]
    )

    answer_button.click(
        answer_question,
        inputs=[full_text_output, question_input],
        outputs=[answer_output]
    )

    # Bind the one-argument wrapper: the event passes exactly one input.
    translate_alternate.click(
        translate_summary,
        inputs=[summary_output],
        outputs=[translate_alternate_output]
    )

    gr.Markdown(
        f"""
## System Information
- Device: {"CUDA" if torch.cuda.is_available() else "CPU"}
- CUDA Available: {"Yes" if torch.cuda.is_available() else "No"}
## ZeroGPU Support
This application supports ZeroGPU for Hugging Face Spaces pro users.
GPU-intensive tasks are automatically optimized for better performance when available.
"""
    )

iface.launch()