File size: 14,214 Bytes
be73d82 bf64010 0642309 be73d82 455bb26 be73d82 0642309 be73d82 0642309 38ca16d 0642309 38ca16d be73d82 0642309 be73d82 0642309 be73d82 0642309 4639733 be73d82 0642309 4639733 be73d82 0642309 be73d82 e1dcf50 0642309 be73d82 9e37bc0 0642309 be73d82 0642309 27d8a18 be73d82 0642309 be73d82 df7fabc 27d8a18 0642309 27d8a18 0642309 2e0024f 0642309 9e37bc0 be73d82 27d8a18 0642309 fab092e 0642309 9e37bc0 6045018 9e37bc0 be73d82 6045018 0642309 bf64010 0642309 9e37bc0 0642309 be73d82 0642309 a87f75c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 |
import torch
import torch.nn.functional as F
import gradio as gr
import librosa
import numpy as np
import cv2
import timm
import os
import time
import spaces
import plotly.express as px
from huggingface_hub import hf_hub_download
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
ASTFeatureExtractor,
ASTForAudioClassification,
AutoModelForCausalLM,
AutoTokenizer
)
from moviepy import VideoFileClip
# --- Configuration ---
# Emotion labels predicted by the audio ensemble; index order must match the
# classifier heads' output order (the checkpoints were trained with it).
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
# Run everything on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ==========================================
# 1. MODEL LOADING
# ==========================================
def load_models():
    """Download and initialise every model used by the app.

    Returns:
        Tuple ``(vlm_model, llm_tokenizer, llm_model, audio_models)`` where
        ``audio_models`` maps pillar ids 'A'/'B'/'C' to audio classifiers and
        'ast_ext' to the AST feature extractor.
    """
    print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
    # Same dtype choice for both generative models: bfloat16 on CUDA, fp32 on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Vision-language model used to describe the video frames.
    vlm_model = AutoModelForImageTextToText.from_pretrained(
        "HuggingFaceTB/SmolVLM2-256M-Video-Instruct", torch_dtype=dtype
    ).to(DEVICE).eval()

    # Small instruction-tuned LLM acting as the final judge.
    judge_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    llm_tok = AutoTokenizer.from_pretrained(judge_id)
    llm_model = AutoModelForCausalLM.from_pretrained(
        judge_id, torch_dtype=dtype
    ).to(DEVICE).eval()

    # Audio ensemble — pillars A and B are ViT-Small classifiers that consume
    # mel-spectrograms rendered as 3-channel 224x224 images.
    audio_models = {}
    for pillar_id, repo_id, ckpt_name in (
        ('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
        ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth'),
    ):
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=ckpt_name)
        vit = timm.create_model(
            "vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3
        )
        vit.load_state_dict(
            torch.load(ckpt_path, map_location=DEVICE)['model_state_dict']
        )
        audio_models[pillar_id] = vit.to(DEVICE).eval()

    # Pillar C — fine-tuned AST. The checkpoint keys carry an 'ast.' prefix
    # that must be stripped; strict=False tolerates the replaced head.
    ast_ckpt = hf_hub_download(
        repo_id="ericjedha/pilier_c",
        filename="best_pillar_c_ast_v95_2_f1_0_9109.pth",
    )
    model_c = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        num_labels=len(CATEGORIES),
        ignore_mismatched_sizes=True,
    )
    state = torch.load(ast_ckpt, map_location=DEVICE)['model_state_dict']
    model_c.load_state_dict(
        {key.replace('ast.', ''): val for key, val in state.items()}, strict=False
    )
    audio_models['C'] = model_c.to(DEVICE).eval()
    audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593"
    )
    return vlm_model, llm_tok, llm_model, audio_models
# Global load at import time so every Gradio request reuses warm models.
vlm_model, llm_tok, llm_model, audio_models = load_models()
# ==========================================
# 2. HYBRID JUDGE (rules + LLM)
# ==========================================
def call_peace_judge(audio_ctx, vlm_desc):
    """
    Deterministic + LLM hybrid judge.

    AUDIO dominates: high-risk audio labels short-circuit immediately, then
    hard visual overrides, then low-risk visual states, and only then a small
    LLM fallback. 'calm' is never a valid verdict.

    Args:
        audio_ctx: Audio summary string such as "PAIN (85.0%)" (label upper-cased).
        vlm_desc:  Free-text visual description produced by the VLM.

    Returns:
        A sentence of the form "The cat is <label>.".
    """
    vlm_lower = vlm_desc.lower()
    audio_upper = audio_ctx.upper()
    # =====================================================
    # 1. HARD AUDIO GUARDRAILS (ABSOLUTE PRIORITY)
    # =====================================================
    if "PAIN" in audio_upper:
        return "The cat is in pain."
    if "ANGRY" in audio_upper:
        return "The cat is angry."
    if "DEFENSIVE" in audio_upper:
        return "The cat is defensive."
    if "BACK_OFF" in audio_upper or "BACKING_OFF" in audio_upper:
        return "The cat is backing off."
    # =====================================================
    # 2. HARD VISUAL OVERRIDES (SAFETY FIRST)
    # =====================================================
    # Aggression / threat display
    if any(x in vlm_lower for x in [
        "front paws raised", "paws raised", "swiping",
        "hissing", "mouth open and tense"
    ]):
        return "The cat is angry."
    # Defensive posture
    if any(x in vlm_lower for x in [
        "arched back", "puffed fur", "ears flat",
        "ears back", "sideways stance"
    ]):
        return "The cat is defensive."
    # Pain indicators
    if any(x in vlm_lower for x in [
        "limping", "hunched", "crouched low",
        "guarding", "withdrawn posture"
    ]):
        return "The cat is in pain."
    # =====================================================
    # 3. POSITIVE / LOW-RISK VISUAL STATES
    # =====================================================
    if any(x in vlm_lower for x in [
        "kneading", "rubbing", "head bumping"
    ]):
        return "The cat is affectionate."
    if any(x in vlm_lower for x in [
        "playful", "rolling", "pouncing"
    ]):
        return "The cat is happy."
    if any(x in vlm_lower for x in [
        "stalking", "tail twitching", "low crawl"
    ]):
        return "The cat is hunting."
    if any(x in vlm_lower for x in [
        "approaching human", "following human",
        "pawing at leg"
    ]):
        return "The cat is wanting attention."
    if any(x in vlm_lower for x in [
        "waiting posture", "looking at food",
        "pacing near bowl"
    ]):
        return "The cat is hungry."
    # =====================================================
    # 4. LLM FALLBACK (NO CALM ALLOWED)
    # =====================================================
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict cat behavior decision engine.\n"
                "Rules:\n"
                "1. AUDIO has priority over vision.\n"
                "2. You must choose the most conservative interpretation.\n"
                "3. 'calm' is NOT a valid output.\n"
                "4. If unsure, prefer defensive or in pain.\n\n"
                "Allowed outputs ONLY:\n"
                "affectionate, angry, backing off, defensive, hungry, happy, "
                "hunting, in heat, calling kittens, in pain, wanting attention\n\n"
                "Answer format EXACTLY:\n"
                "The cat is [label]."
            )
        },
        {
            "role": "user",
            "content": (
                f"AUDIO SIGNAL (PRIMARY): {audio_ctx}\n"
                f"VISION OBSERVATIONS (SECONDARY): {vlm_desc}\n\n"
                "FINAL DECISION:"
            )
        }
    ]
    input_text = llm_tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=15,
            # Greedy decoding; temperature is ignored when do_sample=False,
            # so it is intentionally not passed (avoids a transformers warning).
            do_sample=False,
            pad_token_id=llm_tok.eos_token_id,
            eos_token_id=llm_tok.eos_token_id
        )
    generated = llm_tok.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    ).lower()
    # BUGFIX: the old fallback only matched raw CATEGORIES names
    # ("feed me", "back off", "mother call", "wants attention"), which never
    # appear in the phrases the prompt allows the LLM to emit ("hungry",
    # "backing off", "calling kittens", "wanting attention") — those verdicts
    # silently collapsed into the "defensive" failsafe. Match the allowed
    # phrases explicitly first, longest first to avoid substring ambiguity.
    allowed_phrases = [
        "wanting attention", "calling kittens", "affectionate", "backing off",
        "defensive", "hunting", "in pain", "in heat", "hungry", "angry", "happy",
    ]
    for phrase in allowed_phrases:
        if phrase in generated:
            return f"The cat is {phrase}."
    # Legacy fallback: raw category names (kept for robustness if the LLM
    # echoes a training label verbatim).
    for cat in CATEGORIES:
        if cat.replace("_", " ") in generated:
            return f"The cat is {cat.replace('_', ' ')}."
    # =====================================================
    # 5. FINAL FAILSAFE (NEVER CALM)
    # =====================================================
    return "The cat is defensive."
# ==========================================
# 3. FULL ANALYSIS PIPELINE (FIXED)
# ==========================================
# Lazy one-shot cache for the VLM processor: reloading it from the Hub on
# every request added avoidable latency inside the GPU-billed window.
_VLM_PROC_CACHE = {}


def _get_vlm_processor():
    """Return the SmolVLM2 processor, loading it at most once per process."""
    if "proc" not in _VLM_PROC_CACHE:
        _VLM_PROC_CACHE["proc"] = AutoProcessor.from_pretrained(
            "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
        )
    return _VLM_PROC_CACHE["proc"]


@spaces.GPU(duration=120)
def analyze_cat_v12_final(video_path):
    """Run the full audio + vision + judge pipeline on one uploaded video.

    Args:
        video_path: Path to the video file, or falsy when nothing was uploaded.

    Returns:
        ``(report, fig)``: a formatted text report and a Plotly bar chart of
        the top-5 audio emotion scores, or ``(error message, None)`` on failure.
    """
    if not video_path:
        return "❌ Aucune vidéo.", None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Unique temp WAV name so concurrent requests do not clobber each other.
    tmp_audio = f"temp_{os.getpid()}_{int(time.time())}.wav"
    start_total = time.time()

    # --------------------------------------------------
    # Helper: clean VLM repetitions (cheap & mobile-safe)
    # --------------------------------------------------
    def clean_vlm_output(text):
        # Order-preserving de-duplication of sentences the VLM loops on.
        sentences = text.split(". ")
        cleaned = []
        seen = set()
        for s in sentences:
            key = s.strip().lower()
            if key and key not in seen:
                seen.add(key)
                cleaned.append(s.strip())
        return ". ".join(cleaned)

    try:
        # =========================
        # A. AUDIO
        # =========================
        t_0 = time.time()
        # Default: uniform zero scores when the video carries no audio track.
        audio_probs = np.zeros(len(CATEGORIES))
        clip = VideoFileClip(video_path)
        try:
            if clip.audio:
                clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None)
                w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
                # Pad short clips to 3 s (48000 samples @ 16 kHz) so the
                # spectrogram width is stable.
                if len(w) < 48000:
                    w = np.pad(w, (0, 48000 - len(w)))
                mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
                # Map the dB range roughly into [0, 1] (training-time scaling).
                mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
                img = cv2.resize(
                    (np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8),
                    (224, 224)
                )
                # (224, 224) -> (1, 3, 224, 224) float tensor in [0, 1].
                img_t = (
                    torch.tensor(img)
                    .unsqueeze(0)
                    .repeat(1, 3, 1, 1)
                    .float()
                    .to(DEVICE) / 255.0
                )
                with torch.no_grad():
                    pa = F.softmax(audio_models['A'](img_t), dim=1)
                    pb = F.softmax(audio_models['B'](img_t), dim=1)
                    ic = audio_models['ast_ext'](
                        w, sampling_rate=16000, return_tensors="pt"
                    ).to(DEVICE)
                    pc = F.softmax(audio_models['C'](**ic).logits, dim=1)
                    # Ensemble weights tuned offline.
                    audio_probs = (
                        pa * 0.3468 + pb * 0.2762 + pc * 0.3770
                    ).cpu().numpy()[0]
        finally:
            # Release the video handle even if audio processing raised.
            clip.close()
        t_audio = time.time() - t_0

        # =========================
        # B. VISION (STABILISED VLM)
        # =========================
        t_1 = time.time()
        vlm_proc = _get_vlm_processor()
        vlm_prompt = (
            "You are a feline behavior expert.\n"
            "Describe ONLY observable physical features:\n"
            "- ears position\n"
            "- mouth state (open/closed/tense)\n"
            "- tail position or movement\n"
            "- body posture\n"
            "Use short factual sentences.\n"
            "One observation per sentence.\n"
            "Do NOT interpret mood."
        )
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": vlm_prompt}
                ]
            }
        ]
        vlm_inputs = vlm_proc.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(DEVICE)
        input_length = vlm_inputs["input_ids"].shape[1]
        with torch.no_grad():
            vlm_out = vlm_model.generate(
                **vlm_inputs,
                max_new_tokens=80,
                # Greedy decoding; temperature would be ignored with
                # do_sample=False, so it is intentionally not passed.
                do_sample=False,
                repetition_penalty=1.15,      # anti-loop
                no_repeat_ngram_size=5,       # avoid repeated phrases
                pad_token_id=vlm_proc.tokenizer.eos_token_id,
                eos_token_id=vlm_proc.tokenizer.eos_token_id
            )
        gen_tokens = vlm_out[0][input_length:]
        vlm_clean = vlm_proc.batch_decode(
            [gen_tokens], skip_special_tokens=True
        )[0]
        # Keep only the first line and strip a possible "assistant:" prefix.
        vlm_clean = vlm_clean.strip().split("\n")[0]
        if vlm_clean.lower().startswith("assistant:"):
            vlm_clean = vlm_clean.split(":", 1)[-1].strip()
        # Final anti-repetition cleanup.
        vlm_clean = clean_vlm_output(vlm_clean)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        t_vlm = time.time() - t_1

        # =========================
        # C. JUDGE
        # =========================
        t_2 = time.time()
        top_idx = np.argmax(audio_probs)
        audio_ctx = f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
        judge_decision = call_peace_judge(audio_ctx, vlm_clean)
        t_llm = time.time() - t_2

        # =========================
        # D. VISUALS
        # =========================
        top5 = np.argsort(audio_probs)[-5:][::-1]
        fig = px.bar(
            x=[audio_probs[i] * 100 for i in top5],
            y=[CATEGORIES[i].upper() for i in top5],
            orientation="h",
            title="Top 5 Scores Audio",
            labels={"x": "Probabilité (%)", "y": "Émotion"},
            color=[audio_probs[i] * 100 for i in top5],
            color_continuous_scale="Viridis"
        )
        fig.update_layout(height=400, showlegend=False)

        # =========================
        # E. FINAL REPORT
        # =========================
        t_total = time.time() - start_total
        report = f"""⚖️ VERDICT JUGE : {judge_decision}
------------------------------------------
👁️ VISION : {vlm_clean}
📊 AUDIO : {audio_ctx}
⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Juge {t_llm:.2f}s | Total {t_total:.2f}s"""
        return report, fig
    except Exception as e:
        # Top-level boundary: surface the error in the UI instead of crashing.
        return f"❌ Erreur : {str(e)}", None
    finally:
        # Best-effort temp-file cleanup; a locked/vanished file is not fatal.
        if os.path.exists(tmp_audio):
            try:
                os.remove(tmp_audio)
            except OSError:
                pass
# --- Gradio interface ---
# NOTE: removed a stray trailing "|" after demo.launch() that was a syntax
# error (copy/paste residue).
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
    gr.Markdown("✅ **SmolVLM2-256M** + **SmolLM2-135M Juge** + Audio Ensemble")
    with gr.Row():
        # Left column: video upload + trigger button.
        with gr.Column():
            video_input = gr.Video(label="Vidéo du chat")
            btn = gr.Button("🚀 ANALYSER", variant="primary", size="lg")
        # Right column: text report + audio-score chart.
        with gr.Column():
            report_out = gr.Textbox(label="Résultat complet", lines=12, interactive=False)
            chart_out = gr.Plot(label="Distribution des émotions (Audio)")
    btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])
demo.launch()