| | import streamlit as st |
| | import torch |
| | from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification |
| | import torchaudio |
| | import os |
| | import re |
| | from difflib import SequenceMatcher |
| | import numpy as np |
| |
|
| | |
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model identifier handed to the speech-recognition pipeline, and the language
# code used when building the decoder prompt.
MODEL_NAME = "alvanlii/whisper-small-cantonese"
language = "zh"


@st.cache_resource(show_spinner=False)
def _load_asr_pipeline():
    """Build the speech-recognition pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction; without
    caching, the model would be reloaded each time. The spinner is disabled
    so that no Streamlit element is emitted before ``st.set_page_config``.

    Returns:
        The configured ``automatic-speech-recognition`` pipeline.
    """
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=30,
        device=device,
        # NOTE(review): temperature / top_p / top_k normally only take effect
        # when sampling is enabled (do_sample=True) - confirm the intent.
        generate_kwargs={
            "no_repeat_ngram_size": 3,
            "repetition_penalty": 1.3,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
        },
    )
    # Force the decoder prompt to the configured language and the
    # "transcribe" task.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=language, task="transcribe"
    )
    return asr


pipe = _load_asr_pipeline()
| |
|
| | |
def is_similar(a, b, threshold=0.8):
    """Tell whether two strings are near-duplicates of each other.

    Args:
        a: First string.
        b: Second string.
        threshold: Similarity ratio that must be strictly exceeded.

    Returns:
        True when the ``SequenceMatcher`` ratio of ``a`` and ``b`` is
        greater than ``threshold``, False otherwise.
    """
    matcher = SequenceMatcher(None, a, b)
    similarity = matcher.ratio()
    return similarity > threshold
| |
|
def remove_repeated_phrases(text):
    """Drop sentences that nearly repeat the sentence kept just before them.

    The text is split after each full-width sentence terminator (ideographic
    full stop, full-width exclamation mark, full-width question mark). Each
    sentence is compared with the previously kept one via ``is_similar`` and
    skipped when it is a near-duplicate.

    Args:
        text: Text to clean.

    Returns:
        The kept sentences, stripped and joined with single spaces. Empty
        fragments (for example the one after a trailing terminator) are
        discarded, so the result has no stray leading or trailing space.
    """
    # The terminators are written as escapes so the pattern cannot be damaged
    # by a source-file encoding mix-up. The lookbehind keeps each terminator
    # attached to the sentence it ends.
    sentences = re.split(r'(?<=[\u3002\uff01\uff1f])', text)
    cleaned_sentences = []
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if cleaned_sentences and is_similar(sentence, cleaned_sentences[-1]):
            continue
        cleaned_sentences.append(sentence)
    return " ".join(cleaned_sentences)
| |
|
| | |
def remove_punctuation(text):
    """Strip every character that is neither a word character nor whitespace.

    Args:
        text: Text to clean.

    Returns:
        ``text`` with all characters outside ``\\w`` and ``\\s`` removed.
        Underscores count as word characters and are therefore kept.
    """
    non_word_non_space = re.compile(r'[^\w\s]')
    return non_word_non_space.sub('', text)
| |
|
| | |
def transcribe_audio(audio_path):
    """Transcribe an audio file and return de-duplicated, punctuation-free text.

    The audio is loaded with torchaudio and down-mixed to mono by averaging
    channels. Recordings of at most 60 seconds are sent to ``pipe`` in one
    call. Longer recordings are cut into 55-second windows that advance by
    50 seconds (5 seconds of overlap), and the partial transcripts are joined
    with spaces.

    Args:
        audio_path: Path of the audio file to transcribe.

    Returns:
        The transcript after ``remove_repeated_phrases`` and then
        ``remove_punctuation`` have been applied.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average the channels so the pipeline receives a single mono signal.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    waveform = waveform.squeeze(0).numpy()
    total_samples = waveform.shape[0]

    if total_samples / sample_rate <= 60:
        text = pipe({"sampling_rate": sample_rate, "raw": waveform})["text"]
        return remove_punctuation(remove_repeated_phrases(text))

    chunk_size = sample_rate * 55
    step_size = sample_rate * 50
    results = []
    for start in range(0, total_samples, step_size):
        chunk = waveform[start:start + chunk_size]
        # Keep the punctuation for now: remove_repeated_phrases needs the
        # sentence terminators to find sentence boundaries.
        results.append(pipe({"sampling_rate": sample_rate, "raw": chunk})["text"])
        # This window already reached the end of the recording; another
        # window would lie entirely inside the overlap and duplicate text.
        if start + chunk_size >= total_samples:
            break

    return remove_punctuation(remove_repeated_phrases(" ".join(results)))
| |
|
| | |
@st.cache_resource(show_spinner=False)
def _load_sentiment_pipeline():
    """Build the text-classification pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction; without
    caching, the model would be reloaded each time. The spinner is disabled
    so that no Streamlit element is emitted before ``st.set_page_config``.

    Returns:
        The configured ``text-classification`` pipeline.
    """
    return pipeline(
        "text-classification",
        model="MonkeyDLLLLLLuffy/CustomModel-multilingual-sentiment-analysis-enhanced",
        device=device,
    )


sentiment_pipe = _load_sentiment_pipeline()
| |
|
| | |
def rate_quality(text):
    """Rate a transcript by majority vote over per-chunk sentiment labels.

    The text is cut into 512-character chunks, each chunk is classified by
    ``sentiment_pipe``, and the model labels are mapped to quality ratings.
    The most frequent rating wins; ties go to the rating seen first.

    Args:
        text: Transcript to rate.

    Returns:
        One of "Very Poor", "Poor", "Neutral", "Good", "Very Good", or
        "Unknown" when the text is empty or blank, or when the winning label
        is not in the mapping.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    # Nothing to classify (e.g. silent audio): avoid calling max() on an
    # empty collection, which would raise ValueError.
    if not text or not text.strip():
        return "Unknown"

    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
    # Chunks are sized in characters, not tokens; truncation guards against a
    # chunk exceeding the model's input limit.
    # NOTE(review): assumes the model's limit is around 512 tokens - confirm.
    results = sentiment_pipe(chunks, batch_size=4, truncation=True)

    label_map = {
        "Very Negative": "Very Poor",
        "Negative": "Poor",
        "Neutral": "Neutral",
        "Positive": "Good",
        "Very Positive": "Very Good",
    }
    ratings = [label_map.get(res["label"], "Unknown") for res in results]

    # Counter.most_common orders ties by first occurrence, so the result is
    # deterministic (unlike iterating over a set).
    return Counter(ratings).most_common(1)[0][0]
| |
|
| | |
def main():
    """Render the app: upload audio, transcribe it, rate it, offer a report.

    When a file is uploaded it is written to a uniquely named temporary file,
    played back, transcribed with ``transcribe_audio`` and rated with
    ``rate_quality``. The transcript and rating are displayed and offered as
    a downloadable text report. The temporary file is always removed, even
    when transcription fails.
    """
    # Function-scope import: tempfile is only needed by this entry point.
    import tempfile

    # NOTE(review): several emoji in the UI strings below look mis-encoded.
    # They are kept verbatim here - confirm against the intended characters.
    st.set_page_config(page_title="Customer Service Analyzer", page_icon="๐๏ธ")

    st.markdown("""
        <style>
        .header {
        background: linear-gradient(90deg, #4B79A1, #283E51);
        border-radius: 10px;
        padding: 1.5rem;
        text-align: center;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        margin-bottom: 1.5rem;
        color: white;
        }
        </style>
        """, unsafe_allow_html=True)

    st.markdown("""
        <div class="header">
        <h1 style='margin:0;'>๐๏ธ Customer Service Quality Analyzer</h1>
        <p>Evaluate the service quality with simple uploading!</p>
        </div>
        """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader("๐ค Please upload your Cantonese customer service audio file", type=["wav", "mp3", "flac"])

    if uploaded_file is None:
        return

    # Use a unique temp file so concurrent sessions cannot overwrite each
    # other's upload, and keep the real extension instead of always ".wav".
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_audio_path = tmp.name

    try:
        # Play back with the upload's own MIME type (mp3/flac are allowed too).
        st.audio(uploaded_file, format=uploaded_file.type or "audio/wav")

        with st.spinner('๐ Processing your audio, please wait...'):
            transcript = transcribe_audio(temp_audio_path)
            quality_rating = rate_quality(transcript)

        st.write("**Transcript:**", transcript)
        st.write("**Sentiment Analysis Result:**", quality_rating)

        result_text = f"Transcript:\n{transcript}\n\nSentiment Analysis Result: {quality_rating}"
        st.download_button(label="๐ฅ Download Analysis Report", data=result_text, file_name="analysis_report.txt")

        st.markdown("โIf you encounter any issues, please contact customer support: ๐ง **support@hellotoby.com**")
    finally:
        # Always clean up, including when transcription or rating raises.
        os.remove(temp_audio_path)


if __name__ == "__main__":
    main()
| |
|