import os
import streamlit as st
import numpy as np
import torch
import whisper
from transformers import pipeline, AutoModelForAudioClassification, AutoFeatureExtractor
from deepface import DeepFace
import logging
import soundfile as sf
import tempfile
import cv2
from moviepy.editor import VideoFileClip
import time
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

# Create a cross-platform, writable cache directory for all libraries
CACHE_DIR = os.path.join(tempfile.gettempdir(), "affectlink_cache")
DEEPFACE_CACHE_PATH = os.path.join(CACHE_DIR, ".deepface", "weights")
os.makedirs(DEEPFACE_CACHE_PATH, exist_ok=True)  # Proactively create the full path
os.environ['DEEPFACE_HOME'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR
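# DeepFace reads DEEPFACE_HOME and the Hugging Face libraries read HF_HOME, so all
# downloaded model weights land in the writable temp directory created above.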

# --- Page Configuration ---
st.set_page_config(page_title="AffectLink Demo", page_icon="😊", layout="wide")
st.title("AffectLink: Post-Hoc Emotion Analysis")
st.write("Upload a short video clip (under 30 seconds) to see a multimodal emotion analysis.")

# --- Logger Configuration ---
logging.basicConfig(level=logging.INFO)
logging.getLogger('deepface').setLevel(logging.ERROR)
logging.getLogger('huggingface_hub').setLevel(logging.WARNING)
logging.getLogger('moviepy').setLevel(logging.ERROR)

# --- Emotion Mappings ---
UNIFIED_EMOTIONS = ['angry', 'happy', 'sad', 'neutral']
TEXT_TO_UNIFIED = {'neutral': 'neutral', 'joy': 'happy', 'sadness': 'sad', 'anger': 'angry'}
SER_TO_UNIFIED = {'neu': 'neutral', 'hap': 'happy', 'sad': 'sad', 'ang': 'angry'}
FACIAL_TO_UNIFIED = {'neutral': 'neutral', 'happy': 'happy', 'sad': 'sad', 'angry': 'angry', 'fear': None, 'surprise': None, 'disgust': None}
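# Facial labels with no unified counterpart (fear, surprise, disgust) map to None and are dropped.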
AUDIO_SAMPLE_RATE = 16000
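# 16 kHz is the input sample rate expected by both Whisper and the HuBERT SER model.
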
# --- Model Loading ---
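# Cached with st.cache_resource so the three models are loaded once per process
# rather than on every Streamlit rerun.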
@st.cache_resource
def load_models():
    whisper_model = whisper.load_model("tiny.en", download_root=os.path.join(CACHE_DIR, "whisper"))
    text_classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", top_k=None)
    ser_model_name = "superb/hubert-large-superb-er"
    ser_feature_extractor = AutoFeatureExtractor.from_pretrained(ser_model_name)
    ser_model = AutoModelForAudioClassification.from_pretrained(ser_model_name)
    return whisper_model, text_classifier, ser_model, ser_feature_extractor

with st.spinner("Loading AI models, this may take a moment..."):
    whisper_model, text_classifier, ser_model, ser_feature_extractor = load_models()

# --- Helper Functions for Analysis ---
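# Collapse a raw {label: score} dict onto the unified label set and normalize it
# to sum to 1, so the three modalities are directly comparable.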
def create_unified_vector(scores_dict):
    vector = np.zeros(len(UNIFIED_EMOTIONS))
    total_score = sum(scores_dict.values())
    if total_score > 0:
        for label, score in scores_dict.items():
            if label in UNIFIED_EMOTIONS:
                vector[UNIFIED_EMOTIONS.index(label)] = score / total_score
    return vector

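# Map an average cosine similarity onto a coarse human-readable scale;
# the thresholds are heuristic cut-offs, not calibrated values.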
def get_consistency_level(cosine_sim):
    if np.isnan(cosine_sim):
        return "N/A"
    if cosine_sim >= 0.8:
        return "High"
    if cosine_sim >= 0.6:
        return "Medium"
    if cosine_sim >= 0.3:
        return "Low"
    return "Very Low"

# --- UI and Processing Logic ---
uploaded_file = st.file_uploader("Choose a video file...", type=["mp4", "mov", "avi", "mkv"])

if uploaded_file is not None:
    temp_video_path = None
    video_clip_for_duration = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(uploaded_file.read())
            temp_video_path = tfile.name
        st.video(temp_video_path)

        if st.button("Analyze Video"):
            fer_timeline, ser_timeline, ter_timeline = {}, {}, {}
            full_transcription = "No speech detected."
            video_clip_for_duration = VideoFileClip(temp_video_path)
            duration = video_clip_for_duration.duration
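            # Pass 1: facial emotion recognition on roughly one frame per second.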
| with st.spinner("Analyzing facial expressions..."): | |
| cap = cv2.VideoCapture(temp_video_path) | |
| fps = cap.get(cv2.CAP_PROP_FPS) or 30 | |
| frame_count = 0 | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: break | |
| timestamp = frame_count / fps | |
| if frame_count % int(fps) == 0: | |
| analysis = DeepFace.analyze(frame, actions=['emotion'], enforce_detection=False, silent=True) | |
| if isinstance(analysis, list) and len(analysis) > 0: | |
| fer_timeline[timestamp] = {k: v / 100.0 for k, v in analysis[0]['emotion'].items()} | |
| frame_count += 1 | |
| cap.release() | |
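            # Pass 2: transcribe with Whisper, then score speech emotion (HuBERT) and
            # text emotion (DistilRoBERTa) for each one-second window of audio.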
| with st.spinner("Analyzing audio and text..."): | |
| if video_clip_for_duration.audio: | |
| with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as taudio: | |
| video_clip_for_duration.audio.write_audiofile(taudio.name, fps=AUDIO_SAMPLE_RATE, logger=None) | |
| temp_audio_path = taudio.name | |
| whisper_result = whisper_model.transcribe( | |
| temp_audio_path, | |
| word_timestamps=True, | |
| fp16=False, | |
| condition_on_previous_text=False | |
| ) | |
| full_transcription = whisper_result['text'].strip() | |
| audio_array, _ = sf.read(temp_audio_path, dtype='float32') | |
| if audio_array.ndim == 2: audio_array = audio_array.mean(axis=1) | |
                    for i in range(int(duration)):
                        start_sample, end_sample = i * AUDIO_SAMPLE_RATE, (i + 1) * AUDIO_SAMPLE_RATE
                        chunk = audio_array[start_sample:end_sample]
                        # Skip fragments shorter than 400 samples (25 ms at 16 kHz).
                        if len(chunk) > 400:
                            inputs = ser_feature_extractor(chunk, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt", padding=True)
                            with torch.no_grad():
                                logits = ser_model(**inputs).logits
                            scores = torch.nn.functional.softmax(logits, dim=1).squeeze()
                            ser_timeline[i] = {ser_model.config.id2label[k]: score.item() for k, score in enumerate(scores)}
                        # Collect the words whose start time falls inside this one-second window.
                        words_in_segment = [
                            word['word']
                            for segment in whisper_result.get('segments', [])
                            for word in segment.get('words', [])
                            if i <= word['start'] < i + 1
                        ]
                        segment_text = " ".join(words_in_segment).strip()
                        if segment_text:
                            text_emotions = text_classifier(segment_text)[0]
                            ter_timeline[i] = {emo['label']: emo['score'] for emo in text_emotions}

            st.header("Analysis Results")

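            # Average each modality's per-second scores, fold them onto the unified
            # labels, and report the label with the highest mean score.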
            def process_and_get_dominant(timeline, mapping):
                if not timeline:
                    return "N/A", {}
                df = pd.DataFrame.from_dict(timeline, orient='index')
                unified_scores = {e: 0.0 for e in UNIFIED_EMOTIONS}
                for raw_label, scores in df.items():
                    unified_label = mapping.get(raw_label)
                    if unified_label:
                        unified_scores[unified_label] += scores.mean()
                if sum(unified_scores.values()) == 0:
                    return "N/A", {}
                dominant_emotion = max(unified_scores, key=unified_scores.get)
                return dominant_emotion.capitalize(), unified_scores

            dominant_fer, fer_avg_scores = process_and_get_dominant(fer_timeline, FACIAL_TO_UNIFIED)
            dominant_ser, ser_avg_scores = process_and_get_dominant(ser_timeline, SER_TO_UNIFIED)
            dominant_text, ter_avg_scores = process_and_get_dominant(ter_timeline, TEXT_TO_UNIFIED)
            fer_vector = create_unified_vector(fer_avg_scores)
            ser_vector = create_unified_vector(ser_avg_scores)
            text_vector = create_unified_vector(ter_avg_scores)
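            # Cross-modal agreement: mean pairwise cosine similarity between the three unified vectors.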
            similarities = [
                cosine_similarity([fer_vector], [text_vector])[0][0],
                cosine_similarity([fer_vector], [ser_vector])[0][0],
                cosine_similarity([ser_vector], [text_vector])[0][0],
            ]
            avg_similarity = np.nanmean([s for s in similarities if not np.isnan(s)])
            col1, col2 = st.columns([1, 2])
            with col1:
                st.subheader("Multimodal Summary")
                st.write(f"**Transcription:** \"{full_transcription}\"")
                st.metric("Dominant Facial Emotion", dominant_fer)
                st.metric("Dominant Text Emotion", dominant_text)
                st.metric("Dominant Speech Emotion", dominant_ser)
                st.metric("Emotion Consistency", get_consistency_level(avg_similarity), f"{avg_similarity:.2f} Avg. Cosine Similarity")
            with col2:
                st.subheader("Unified Emotion Timeline")

                def create_timeline_df(timeline, mapping):
                    if not timeline:
                        return pd.DataFrame(columns=UNIFIED_EMOTIONS)
                    df = pd.DataFrame.from_dict(timeline, orient='index')
                    df_unified = pd.DataFrame(index=df.index, columns=UNIFIED_EMOTIONS).fillna(0.0)
                    for raw_col in df.columns:
                        unified_col = mapping.get(raw_col)
                        if unified_col:
                            df_unified[unified_col] += df[raw_col]
                    return df_unified

                fer_df = create_timeline_df(fer_timeline, FACIAL_TO_UNIFIED)
                ser_df = create_timeline_df(ser_timeline, SER_TO_UNIFIED)
                ter_df = create_timeline_df(ter_timeline, TEXT_TO_UNIFIED)
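                # Resample every modality onto a common 0.5-second grid via linear interpolation.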
                full_index = np.arange(0, duration, 0.5)
                combined_df = pd.DataFrame(index=full_index)
                if not fer_df.empty:
                    fer_df_resampled = fer_df.reindex(fer_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                    for e in UNIFIED_EMOTIONS:
                        combined_df[f'Facial_{e}'] = fer_df_resampled.get(e, 0.0)
                if not ser_df.empty:
                    ser_df_resampled = ser_df.reindex(ser_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                    for e in UNIFIED_EMOTIONS:
                        combined_df[f'Speech_{e}'] = ser_df_resampled.get(e, 0.0)
                if not ter_df.empty:
                    ter_df_resampled = ter_df.reindex(ter_df.index.union(full_index)).interpolate(method='linear').reindex(full_index)
                    for e in UNIFIED_EMOTIONS:
                        combined_df[f'Text_{e}'] = ter_df_resampled.get(e, 0.0)
                combined_df.fillna(0, inplace=True)
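                # Plot: one color per unified emotion, one line style per modality.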
                if not combined_df.empty:
                    fig, ax = plt.subplots(figsize=(10, 5))
                    colors = {'happy': 'green', 'sad': 'blue', 'angry': 'red', 'neutral': 'gray'}
                    styles = {'Facial': '-', 'Speech': '--', 'Text': ':'}
                    for col in combined_df.columns:
                        modality, emotion = col.split('_')
                        if emotion in colors:
                            ax.plot(combined_df.index, combined_df[col], label=f'{modality} {emotion.capitalize()}', color=colors[emotion], linestyle=styles[modality], alpha=0.8)
                    ax.set_title("Emotion Confidence Over Time (Normalized)")
                    ax.set_xlabel("Time (seconds)")
                    ax.set_ylabel("Confidence Score (0-1)")
                    ax.set_ylim(0, 1)
                    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
                    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
                    plt.tight_layout()
                    st.pyplot(fig)
                else:
                    st.write("No emotion data available to plot.")
    finally:
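        # Best-effort cleanup: close the clip, then remove the temporary files.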
        if 'video_clip_for_duration' in locals() and video_clip_for_duration:
            video_clip_for_duration.close()
        if 'temp_audio_path' in locals() and temp_audio_path and os.path.exists(temp_audio_path):
            os.unlink(temp_audio_path)
        if temp_video_path and os.path.exists(temp_video_path):
            time.sleep(1)  # brief pause in case the file is still held open (e.g., by st.video)
            try:
                os.unlink(temp_video_path)
            except Exception:
                pass