"""CatSense v12.13 — cat-emotion analysis from short videos.

Three pillars feed a final verdict:
  * Audio  : ensemble of two ViT spectrogram classifiers (A, B) + an AST (C).
  * Vision : SmolVLM2 describes observable feline body language only.
  * Judge  : deterministic guardrails first, then a small LLM fallback.
"""

from functools import lru_cache
import os
import time

import cv2
import gradio as gr
import librosa
import numpy as np
import plotly.express as px
import spaces
import timm
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from moviepy import VideoFileClip
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    ASTFeatureExtractor,
    ASTForAudioClassification,
    AutoModelForCausalLM,
    AutoTokenizer,
)

# --- Configuration ---
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me',
              'happy', 'hunt', 'in_heat', 'mother_call', 'pain',
              'wants_attention']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Human-readable labels the judge LLM is *allowed* to emit (they appear
# verbatim in its system prompt).  Used to parse the LLM output back into a
# verdict sentence — parsing against CATEGORIES was a bug, since names such
# as "feed me" or "wants attention" can never occur in the LLM's answer.
LLM_LABELS = ["affectionate", "angry", "backing off", "defensive", "hungry",
              "happy", "hunting", "in heat", "calling kittens", "in pain",
              "wanting attention"]


# ==========================================
# 1. MODEL LOADING
# ==========================================
def load_models():
    """Download and initialise every model once at startup.

    Returns:
        tuple: ``(vlm_model, llm_tok, llm_model, audio_models)`` where
        ``audio_models`` maps ``'A'``/``'B'``/``'C'`` to audio classifiers
        and ``'ast_ext'`` to the AST feature extractor.
    """
    print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Vision-language model (video description).
    vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
    vlm_model = AutoModelForImageTextToText.from_pretrained(
        vlm_id, torch_dtype=dtype
    ).to(DEVICE).eval()

    # Judge LLM.
    llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    llm_tok = AutoTokenizer.from_pretrained(llm_id)
    llm_model = AutoModelForCausalLM.from_pretrained(
        llm_id, torch_dtype=dtype
    ).to(DEVICE).eval()

    # Audio pillars A/B: ViT classifiers over mel-spectrogram images.
    # NOTE(review): torch.load without weights_only=True executes arbitrary
    # pickle code — acceptable only because the checkpoints come from these
    # fixed Hugging Face repos; do not point this at untrusted files.
    audio_models = {}
    for pillar, repo, fname in [
        ('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
        ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth'),
    ]:
        path = hf_hub_download(repo_id=repo, filename=fname)
        m = timm.create_model("vit_small_patch16_224",
                              num_classes=len(CATEGORIES), in_chans=3)
        m.load_state_dict(
            torch.load(path, map_location=DEVICE)['model_state_dict']
        )
        audio_models[pillar] = m.to(DEVICE).eval()

    # Audio pillar C: AST fine-tuned on the same categories.
    path_c = hf_hub_download(
        repo_id="ericjedha/pilier_c",
        filename="best_pillar_c_ast_v95_2_f1_0_9109.pth",
    )
    model_c = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        num_labels=len(CATEGORIES),
        ignore_mismatched_sizes=True,
    )
    sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
    # Checkpoint keys carry an 'ast.' prefix; strip it to match the model.
    model_c.load_state_dict(
        {k.replace('ast.', ''): v for k, v in sd.items()}, strict=False
    )
    audio_models['C'] = model_c.to(DEVICE).eval()
    audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593"
    )
    return vlm_model, llm_tok, llm_model, audio_models


# Global load, once at import time.
vlm_model, llm_tok, llm_model, audio_models = load_models()


@lru_cache(maxsize=1)
def _get_vlm_processor():
    """Load the SmolVLM2 processor once and cache it.

    The previous version called ``AutoProcessor.from_pretrained`` inside
    every analysis request, re-resolving the hub files each time.
    """
    return AutoProcessor.from_pretrained(
        "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
    )


# ==========================================
# 2. HYBRID JUDGE (rules + LLM)
# ==========================================
def call_peace_judge(audio_ctx, vlm_desc):
    """
    Deterministic + LLM hybrid judge.
    AUDIO dominates when confidence > 30%.
    Vision can refine but never neutralize strong audio signals.

    Args:
        audio_ctx: string like "ANGRY (73.2%)" — top audio label + score.
        vlm_desc:  free-text physical description produced by the VLM.

    Returns:
        A sentence of the form "The cat is <label>." — never "calm".
    """
    vlm_lower = vlm_desc.lower()
    audio_upper = audio_ctx.upper()

    # =====================================================
    # 1. HARD AUDIO GUARDRAILS (ABSOLUTE PRIORITY)
    # =====================================================
    if "PAIN" in audio_upper:
        return "The cat is in pain."
    if "ANGRY" in audio_upper:
        return "The cat is angry."
    if "DEFENSIVE" in audio_upper:
        return "The cat is defensive."
    if "BACK_OFF" in audio_upper or "BACKING_OFF" in audio_upper:
        return "The cat is backing off."

    # =====================================================
    # 2. HARD VISUAL OVERRIDES (SAFETY FIRST)
    # =====================================================
    # Aggression / threat display
    if any(x in vlm_lower for x in [
        "front paws raised", "paws raised", "swiping", "hissing",
        "mouth open and tense"
    ]):
        return "The cat is angry."

    # Defensive posture
    if any(x in vlm_lower for x in [
        "arched back", "puffed fur", "ears flat", "ears back",
        "sideways stance"
    ]):
        return "The cat is defensive."

    # Pain indicators
    if any(x in vlm_lower for x in [
        "limping", "hunched", "crouched low", "guarding",
        "withdrawn posture"
    ]):
        return "The cat is in pain."

    # =====================================================
    # 3. POSITIVE / LOW-RISK VISUAL STATES
    # =====================================================
    if any(x in vlm_lower for x in ["kneading", "rubbing", "head bumping"]):
        return "The cat is affectionate."
    if any(x in vlm_lower for x in ["playful", "rolling", "pouncing"]):
        return "The cat is happy."
    if any(x in vlm_lower for x in ["stalking", "tail twitching", "low crawl"]):
        return "The cat is hunting."
    if any(x in vlm_lower for x in [
        "approaching human", "following human", "pawing at leg"
    ]):
        return "The cat is wanting attention."
    if any(x in vlm_lower for x in [
        "waiting posture", "looking at food", "pacing near bowl"
    ]):
        return "The cat is hungry."

    # =====================================================
    # 4. LLM FALLBACK (NO CALM ALLOWED)
    # =====================================================
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict cat behavior decision engine.\n"
                "Rules:\n"
                "1. AUDIO has priority over vision.\n"
                "2. You must choose the most conservative interpretation.\n"
                "3. 'calm' is NOT a valid output.\n"
                "4. If unsure, prefer defensive or in pain.\n\n"
                "Allowed outputs ONLY:\n"
                "affectionate, angry, backing off, defensive, hungry, happy, "
                "hunting, in heat, calling kittens, in pain, wanting attention\n\n"
                "Answer format EXACTLY:\n"
                "The cat is [label]."
            )
        },
        {
            "role": "user",
            "content": (
                f"AUDIO SIGNAL (PRIMARY): {audio_ctx}\n"
                f"VISION OBSERVATIONS (SECONDARY): {vlm_desc}\n\n"
                "FINAL DECISION:"
            )
        }
    ]

    input_text = llm_tok.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # BUGFIX: temperature=0.0 removed — it is invalid/ignored when
        # do_sample=False and triggers warnings in recent transformers.
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=15,
            do_sample=False,
            pad_token_id=llm_tok.eos_token_id,
            eos_token_id=llm_tok.eos_token_id,
        )

    generated = llm_tok.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).lower()

    # BUGFIX: match the labels the prompt actually allows the LLM to emit.
    # The old loop searched for CATEGORIES names ("feed me", "back off",
    # "wants attention", "mother call") which the prompt forbids, so those
    # verdicts always fell through to the "defensive" failsafe.
    for label in LLM_LABELS:
        if label in generated:
            return f"The cat is {label}."

    # =====================================================
    # 5. FINAL FAILSAFE (NEVER CALM)
    # =====================================================
    return "The cat is defensive."


# ==========================================
# 3. FULL ANALYSIS PIPELINE
# ==========================================
@spaces.GPU(duration=120)
def analyze_cat_v12_final(video_path):
    """Run the full audio + vision + judge pipeline on one video.

    Args:
        video_path: path to the uploaded video file (or None).

    Returns:
        tuple: (text report, plotly bar figure of top-5 audio scores),
        or an error message and None on failure.
    """
    if not video_path:
        return "❌ Aucune vidéo.", None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Unique temp path so concurrent requests don't clobber each other.
    tmp_audio = f"temp_{os.getpid()}_{int(time.time())}.wav"
    start_total = time.time()

    # --------------------------------------------------
    # Helper: clean VLM repetitions (cheap & mobile-safe)
    # --------------------------------------------------
    def clean_vlm_output(text):
        # Deduplicate sentences case-insensitively, preserving first order.
        sentences = text.split(". ")
        cleaned = []
        seen = set()
        for s in sentences:
            key = s.strip().lower()
            if key and key not in seen:
                seen.add(key)
                cleaned.append(s.strip())
        return ". ".join(cleaned)

    try:
        # =========================
        # A. AUDIO
        # =========================
        t_0 = time.time()
        clip = VideoFileClip(video_path)
        # NOTE(review): with no audio track the probs stay all-zero, so the
        # report shows "AFFECTION (0.0%)" (argmax of zeros) — consider a
        # dedicated "no audio" context string.
        audio_probs = np.zeros(len(CATEGORIES))

        if clip.audio:
            clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None)
            # First 5 s at 16 kHz, zero-padded to exactly 48000 samples (3 s
            # minimum expected by the spectrogram pipeline).
            w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
            if len(w) < 48000:
                w = np.pad(w, (0, 48000 - len(w)))

            mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
            mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
            # Pad 192 mel bins with 10 zero rows, scale to uint8, resize to
            # the ViT input size.  NOTE(review): mel_db can be < 0 (quiet
            # bins), so the uint8 cast wraps — matches the training-time
            # preprocessing, presumably; verify against the pillar repos.
            img = cv2.resize(
                (np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255)
                .astype(np.uint8),
                (224, 224),
            )
            img_t = (
                torch.tensor(img)
                .unsqueeze(0)
                .repeat(1, 3, 1, 1)  # grayscale -> 3-channel (1, 3, 224, 224)
                .float()
                .to(DEVICE)
                / 255.0
            )

            with torch.no_grad():
                pa = F.softmax(audio_models['A'](img_t), dim=1)
                pb = F.softmax(audio_models['B'](img_t), dim=1)
                ic = audio_models['ast_ext'](
                    w, sampling_rate=16000, return_tensors="pt"
                ).to(DEVICE)
                pc = F.softmax(audio_models['C'](**ic).logits, dim=1)
                # Ensemble weights tuned offline for pillars A/B/C.
                audio_probs = (
                    pa * 0.3468 + pb * 0.2762 + pc * 0.3770
                ).cpu().numpy()[0]

        clip.close()
        t_audio = time.time() - t_0

        # =========================
        # B. VISION (STABILISED VLM)
        # =========================
        t_1 = time.time()
        vlm_proc = _get_vlm_processor()

        vlm_prompt = (
            "You are a feline behavior expert.\n"
            "Describe ONLY observable physical features:\n"
            "- ears position\n"
            "- mouth state (open/closed/tense)\n"
            "- tail position or movement\n"
            "- body posture\n"
            "Use short factual sentences.\n"
            "One observation per sentence.\n"
            "Do NOT interpret mood."
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": vlm_prompt},
                ]
            }
        ]

        vlm_inputs = vlm_proc.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(DEVICE)
        input_length = vlm_inputs["input_ids"].shape[1]

        with torch.no_grad():
            # BUGFIX: temperature=0.0 removed (invalid with do_sample=False).
            vlm_out = vlm_model.generate(
                **vlm_inputs,
                max_new_tokens=80,
                do_sample=False,
                repetition_penalty=1.15,   # anti-loop
                no_repeat_ngram_size=5,    # anti repeated phrase
                pad_token_id=vlm_proc.tokenizer.eos_token_id,
                eos_token_id=vlm_proc.tokenizer.eos_token_id,
            )

        gen_tokens = vlm_out[0][input_length:]
        vlm_clean = vlm_proc.batch_decode(
            [gen_tokens], skip_special_tokens=True
        )[0]
        vlm_clean = vlm_clean.strip().split("\n")[0]
        if vlm_clean.lower().startswith("assistant:"):
            vlm_clean = vlm_clean.split(":", 1)[-1].strip()
        # Final anti-repetition cleanup.
        vlm_clean = clean_vlm_output(vlm_clean)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        t_vlm = time.time() - t_1

        # =========================
        # C. JUDGE
        # =========================
        t_2 = time.time()
        top_idx = np.argmax(audio_probs)
        audio_ctx = (
            f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
        )
        judge_decision = call_peace_judge(audio_ctx, vlm_clean)
        t_llm = time.time() - t_2

        # =========================
        # D. VISUALS
        # =========================
        top5 = np.argsort(audio_probs)[-5:][::-1]
        fig = px.bar(
            x=[audio_probs[i] * 100 for i in top5],
            y=[CATEGORIES[i].upper() for i in top5],
            orientation="h",
            title="Top 5 Scores Audio",
            labels={"x": "Probabilité (%)", "y": "Émotion"},
            color=[audio_probs[i] * 100 for i in top5],
            color_continuous_scale="Viridis",
        )
        fig.update_layout(height=400, showlegend=False)

        # =========================
        # E. FINAL REPORT
        # =========================
        t_total = time.time() - start_total
        report = f"""⚖️ VERDICT JUGE : {judge_decision}
------------------------------------------
👁️ VISION : {vlm_clean}
📊 AUDIO : {audio_ctx}
⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Juge {t_llm:.2f}s | Total {t_total:.2f}s"""

        return report, fig

    except Exception as e:
        # Top-level boundary: surface the error to the UI instead of a 500.
        return f"❌ Erreur : {str(e)}", None

    finally:
        if os.path.exists(tmp_audio):
            try:
                os.remove(tmp_audio)
            except OSError:
                # Best-effort cleanup; a stale temp file is harmless.
                pass


# --- Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
    gr.Markdown("✅ **SmolVLM2-256M** + **SmolLM2-135M Juge** + Audio Ensemble")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Vidéo du chat")
            btn = gr.Button("🚀 ANALYSER", variant="primary", size="lg")
        with gr.Column():
            report_out = gr.Textbox(
                label="Résultat complet", lines=12, interactive=False
            )
            chart_out = gr.Plot(label="Distribution des émotions (Audio)")
    btn.click(
        analyze_cat_v12_final,
        inputs=video_input,
        outputs=[report_out, chart_out],
    )

demo.launch()