# crazycat / app.py
# (page-header residue from the hosting site, kept as comments so the file parses)
# ericjedha's picture
# Update app.py
# 0642309 verified
import torch
import torch.nn.functional as F
import gradio as gr
import librosa
import numpy as np
import cv2
import timm
import os
import time
import spaces
import plotly.express as px
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
ASTFeatureExtractor,
ASTForAudioClassification,
AutoModelForCausalLM,
AutoTokenizer
)
from moviepy import VideoFileClip
# --- Configuration ---
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================
# 1. CHARGEMENT DES MODÈLES
# ==========================================
def load_models():
    """Download and initialise every model used by CatSense.

    Loads, in order:
      * SmolVLM2 (vision-language model) for describing the video frames,
      * SmolLM2 (small instruct LLM) used as the final judge,
      * the audio ensemble: two ViT-small spectrogram classifiers
        (pillars A and B), one fine-tuned AST classifier (pillar C),
        and the AST feature extractor.

    Returns:
        tuple: (vlm_model, llm_tokenizer, llm_model, audio_models) where
        ``audio_models`` maps 'A'/'B'/'C' to classifiers and 'ast_ext'
        to the AST feature extractor.
    """
    print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
    # bfloat16 on GPU; plain float32 on CPU where bf16 is slow/unsupported.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Vision-language model
    vlm = AutoModelForImageTextToText.from_pretrained(
        "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", torch_dtype=dtype
    ).to(DEVICE).eval()

    # Judge LLM and its tokenizer
    judge_tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
    judge_llm = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M-Instruct", torch_dtype=dtype
    ).to(DEVICE).eval()

    # Pillars A and B: ViT-small classifiers over mel-spectrogram images.
    audio = {}
    for key, repo_id, fname in (
        ('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
        ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth'),
    ):
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=fname)
        vit = timm.create_model(
            "vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3
        )
        vit.load_state_dict(torch.load(ckpt_path, map_location=DEVICE)['model_state_dict'])
        audio[key] = vit.to(DEVICE).eval()

    # Pillar C: fine-tuned AST. Checkpoint keys carry an 'ast.' prefix, so
    # strip it and load non-strictly.
    ckpt_c = hf_hub_download(
        repo_id="ericjedha/pilier_c",
        filename="best_pillar_c_ast_v95_2_f1_0_9109.pth",
    )
    ast_model = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        num_labels=len(CATEGORIES),
        ignore_mismatched_sizes=True,
    )
    raw_sd = torch.load(ckpt_c, map_location=DEVICE)['model_state_dict']
    ast_model.load_state_dict(
        {k.replace('ast.', ''): v for k, v in raw_sd.items()}, strict=False
    )
    audio['C'] = ast_model.to(DEVICE).eval()
    audio['ast_ext'] = ASTFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593"
    )

    return vlm, judge_tok, judge_llm, audio
# Load every model once at import time so all Gradio requests share them.
vlm_model, llm_tok, llm_model, audio_models = load_models()
# ==========================================
# 2. JUGE HYBRIDE (règles + LLM)
# ==========================================
def call_peace_judge(audio_ctx, vlm_desc):
    """Deterministic + LLM hybrid judge.

    Priority order:
      1. Hard audio guardrails (pain/anger/defensive/back-off always win).
      2. Hard visual safety overrides (threat, defensive, pain cues).
      3. Positive / low-risk visual states.
      4. LLM fallback constrained to a fixed label set.
      5. Failsafe: "defensive" (never "calm").

    NOTE(review): the original docstring claimed "AUDIO dominates when
    confidence > 30%", but no confidence threshold is applied anywhere —
    the guardrails fire on the label alone. Kept as-is to preserve
    behavior; confirm intent.

    Args:
        audio_ctx: top audio label plus confidence, e.g. "ANGRY (72.4%)".
        vlm_desc: free-text visual description produced by the VLM.

    Returns:
        str: a sentence of the form "The cat is <label>."
    """
    vlm_lower = vlm_desc.lower()
    audio_upper = audio_ctx.upper()

    # =====================================================
    # 1. HARD AUDIO GUARDRAILS (ABSOLUTE PRIORITY)
    # =====================================================
    if "PAIN" in audio_upper:
        return "The cat is in pain."
    if "ANGRY" in audio_upper:
        return "The cat is angry."
    if "DEFENSIVE" in audio_upper:
        return "The cat is defensive."
    if "BACK_OFF" in audio_upper or "BACKING_OFF" in audio_upper:
        return "The cat is backing off."

    # =====================================================
    # 2. HARD VISUAL OVERRIDES (SAFETY FIRST)
    # =====================================================
    # Aggression / threat display
    if any(x in vlm_lower for x in [
        "front paws raised", "paws raised", "swiping",
        "hissing", "mouth open and tense"
    ]):
        return "The cat is angry."
    # Defensive posture
    if any(x in vlm_lower for x in [
        "arched back", "puffed fur", "ears flat",
        "ears back", "sideways stance"
    ]):
        return "The cat is defensive."
    # Pain indicators
    if any(x in vlm_lower for x in [
        "limping", "hunched", "crouched low",
        "guarding", "withdrawn posture"
    ]):
        return "The cat is in pain."

    # =====================================================
    # 3. POSITIVE / LOW-RISK VISUAL STATES
    # =====================================================
    if any(x in vlm_lower for x in [
        "kneading", "rubbing", "head bumping"
    ]):
        return "The cat is affectionate."
    if any(x in vlm_lower for x in [
        "playful", "rolling", "pouncing"
    ]):
        return "The cat is happy."
    if any(x in vlm_lower for x in [
        "stalking", "tail twitching", "low crawl"
    ]):
        return "The cat is hunting."
    if any(x in vlm_lower for x in [
        "approaching human", "following human",
        "pawing at leg"
    ]):
        return "The cat is wanting attention."
    if any(x in vlm_lower for x in [
        "waiting posture", "looking at food",
        "pacing near bowl"
    ]):
        return "The cat is hungry."

    # =====================================================
    # 4. LLM FALLBACK (NO CALM ALLOWED)
    # =====================================================
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict cat behavior decision engine.\n"
                "Rules:\n"
                "1. AUDIO has priority over vision.\n"
                "2. You must choose the most conservative interpretation.\n"
                "3. 'calm' is NOT a valid output.\n"
                "4. If unsure, prefer defensive or in pain.\n\n"
                "Allowed outputs ONLY:\n"
                "affectionate, angry, backing off, defensive, hungry, happy, "
                "hunting, in heat, calling kittens, in pain, wanting attention\n\n"
                "Answer format EXACTLY:\n"
                "The cat is [label]."
            )
        },
        {
            "role": "user",
            "content": (
                f"AUDIO SIGNAL (PRIMARY): {audio_ctx}\n"
                f"VISION OBSERVATIONS (SECONDARY): {vlm_desc}\n\n"
                "FINAL DECISION:"
            )
        }
    ]
    input_text = llm_tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=15,
            do_sample=False,  # greedy decoding; temperature is ignored (and
                              # warned about) under do_sample=False, so dropped
            pad_token_id=llm_tok.eos_token_id,
            eos_token_id=llm_tok.eos_token_id
        )
    generated = llm_tok.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    ).lower()

    # FIX: the prompt's allowed outputs ("backing off", "hungry",
    # "wanting attention", "calling kittens") did not match the internal
    # category names checked previously ("back_off", "feed_me",
    # "wants_attention", "mother_call"), so several valid LLM answers fell
    # through to the failsafe, and "hunt" produced "The cat is hunt.".
    # Map both phrasings to one canonical label; scan longer phrases first
    # so e.g. "in pain" wins over the bare "pain".
    phrase_to_label = {
        "affectionate": "affectionate",
        "affection": "affectionate",
        "angry": "angry",
        "backing off": "backing off",
        "back off": "backing off",
        "defensive": "defensive",
        "hungry": "hungry",
        "feed me": "hungry",
        "happy": "happy",
        "hunting": "hunting",
        "hunt": "hunting",
        "in heat": "in heat",
        "calling kittens": "calling kittens",
        "mother call": "calling kittens",
        "in pain": "in pain",
        "pain": "in pain",
        "wanting attention": "wanting attention",
        "wants attention": "wanting attention",
    }
    for phrase in sorted(phrase_to_label, key=len, reverse=True):
        if phrase in generated:
            return f"The cat is {phrase_to_label[phrase]}."

    # =====================================================
    # 5. FINAL FAILSAFE (NEVER CALM)
    # =====================================================
    return "The cat is defensive."
# ==========================================
# 3. PIPELINE ANALYSE COMPLETE (CORRIGÉ)
# ==========================================
@spaces.GPU(duration=120)
def analyze_cat_v12_final(video_path):
    """End-to-end analysis of one cat video.

    Pipeline: (A) audio ensemble over a mel-spectrogram + AST features,
    (B) VLM visual description, (C) hybrid rule/LLM judge, (D) plotly
    chart of top-5 audio scores, (E) text report.

    Args:
        video_path: filesystem path of the uploaded video, or None/empty.

    Returns:
        tuple: (report_text, plotly_figure_or_None). On error the report
        carries the error message and the figure is None.
    """
    if not video_path:
        return "❌ Aucune vidéo.", None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # PID + timestamp keeps concurrent requests from sharing a temp file.
    tmp_audio = f"temp_{os.getpid()}_{int(time.time())}.wav"
    start_total = time.time()

    # --------------------------------------------------
    # Helper: clean VLM repetitions (cheap & mobile-safe)
    # --------------------------------------------------
    def clean_vlm_output(text):
        # Deduplicate repeated sentences — small VLMs tend to loop.
        sentences = text.split(". ")
        cleaned = []
        seen = set()
        for s in sentences:
            key = s.strip().lower()
            if key and key not in seen:
                seen.add(key)
                cleaned.append(s.strip())
        return ". ".join(cleaned)

    try:
        # =========================
        # A. AUDIO
        # =========================
        t_0 = time.time()
        clip = VideoFileClip(video_path)
        audio_probs = np.zeros(len(CATEGORIES))
        try:
            # FIX: clip is now closed in a finally so the underlying ffmpeg
            # readers are released even when audio processing raises
            # (previously leaked on any exception below).
            if clip.audio:
                clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None)
                w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
                if len(w) < 48000:
                    # Pad to at least 3 s @ 16 kHz so the spectrogram has a
                    # stable minimum width.
                    w = np.pad(w, (0, 48000 - len(w)))
                mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
                # Shift/scale the dB range into roughly [0, 1].
                mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
                # Stack 10 blank rows, quantize to uint8, resize to ViT input.
                img = cv2.resize(
                    (np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8),
                    (224, 224)
                )
                # (224, 224) uint8 -> (1, 3, 224, 224) float in [0, 1]
                img_t = (
                    torch.tensor(img)
                    .unsqueeze(0)
                    .repeat(1, 3, 1, 1)
                    .float()
                    .to(DEVICE) / 255.0
                )
                with torch.no_grad():
                    pa = F.softmax(audio_models['A'](img_t), dim=1)
                    pb = F.softmax(audio_models['B'](img_t), dim=1)
                    ic = audio_models['ast_ext'](
                        w, sampling_rate=16000, return_tensors="pt"
                    ).to(DEVICE)
                    pc = F.softmax(audio_models['C'](**ic).logits, dim=1)
                # Ensemble weights — presumably tuned offline on a
                # validation split; TODO confirm provenance.
                audio_probs = (
                    pa * 0.3468 + pb * 0.2762 + pc * 0.3770
                ).cpu().numpy()[0]
        finally:
            clip.close()
        t_audio = time.time() - t_0

        # =========================
        # B. VISION (stabilized VLM)
        # =========================
        t_1 = time.time()
        # NOTE(review): the processor is re-created on every call; it could
        # be cached in load_models(). Left in place to keep behavior.
        vlm_proc = AutoProcessor.from_pretrained(
            "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
        )
        vlm_prompt = (
            "You are a feline behavior expert.\n"
            "Describe ONLY observable physical features:\n"
            "- ears position\n"
            "- mouth state (open/closed/tense)\n"
            "- tail position or movement\n"
            "- body posture\n"
            "Use short factual sentences.\n"
            "One observation per sentence.\n"
            "Do NOT interpret mood."
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": vlm_prompt}
                ]
            }
        ]
        vlm_inputs = vlm_proc.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(DEVICE)
        input_length = vlm_inputs["input_ids"].shape[1]
        with torch.no_grad():
            vlm_out = vlm_model.generate(
                **vlm_inputs,
                max_new_tokens=80,
                do_sample=False,
                temperature=0.0,
                repetition_penalty=1.15,   # anti-loop
                no_repeat_ngram_size=5,    # blocks repeated phrases
                pad_token_id=vlm_proc.tokenizer.eos_token_id,
                eos_token_id=vlm_proc.tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens (skip the prompt).
        gen_tokens = vlm_out[0][input_length:]
        vlm_clean = vlm_proc.batch_decode(
            [gen_tokens], skip_special_tokens=True
        )[0]
        vlm_clean = vlm_clean.strip().split("\n")[0]
        if vlm_clean.lower().startswith("assistant:"):
            vlm_clean = vlm_clean.split(":", 1)[-1].strip()
        # Final anti-repetition cleanup.
        vlm_clean = clean_vlm_output(vlm_clean)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        t_vlm = time.time() - t_1

        # =========================
        # C. JUDGE
        # =========================
        t_2 = time.time()
        top_idx = np.argmax(audio_probs)
        audio_ctx = f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
        judge_decision = call_peace_judge(audio_ctx, vlm_clean)
        t_llm = time.time() - t_2

        # =========================
        # D. CHART
        # =========================
        top5 = np.argsort(audio_probs)[-5:][::-1]
        fig = px.bar(
            x=[audio_probs[i] * 100 for i in top5],
            y=[CATEGORIES[i].upper() for i in top5],
            orientation="h",
            title="Top 5 Scores Audio",
            labels={"x": "Probabilité (%)", "y": "Émotion"},
            color=[audio_probs[i] * 100 for i in top5],
            color_continuous_scale="Viridis"
        )
        fig.update_layout(height=400, showlegend=False)

        # =========================
        # E. FINAL REPORT
        # =========================
        t_total = time.time() - start_total
        report = f"""⚖️ VERDICT JUGE : {judge_decision}
------------------------------------------
👁️ VISION : {vlm_clean}
📊 AUDIO : {audio_ctx}
⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Juge {t_llm:.2f}s | Total {t_total:.2f}s"""
        return report, fig
    except Exception as e:
        return f"❌ Erreur : {str(e)}", None
    finally:
        # Best-effort removal of the temp WAV.
        if os.path.exists(tmp_audio):
            try:
                os.remove(tmp_audio)
            except OSError:
                # FIX: was a bare `except:` — only swallow filesystem errors.
                pass
# --- Gradio interface ---
# Two-column layout: video input + analyse button on the left,
# textual report and top-5 audio-probability chart on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
    gr.Markdown("✅ **SmolVLM2-256M** + **SmolLM2-135M Juge** + Audio Ensemble")
    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Vidéo du chat")
            btn = gr.Button("🚀 ANALYSER", variant="primary", size="lg")
        with gr.Column():
            report_out = gr.Textbox(label="Résultat complet", lines=12, interactive=False)
            chart_out = gr.Plot(label="Distribution des émotions (Audio)")
    # Wire the button to the analysis pipeline.
    btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])

demo.launch()