|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
import librosa |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import timm |
|
|
import os |
|
|
import time |
|
|
import spaces |
|
|
import plotly.express as px |
|
|
from huggingface_hub import hf_hub_download |
|
|
from transformers import ( |
|
|
AutoProcessor, |
|
|
AutoModelForImageTextToText, |
|
|
ASTFeatureExtractor, |
|
|
ASTForAudioClassification, |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer |
|
|
) |
|
|
from moviepy import VideoFileClip |
|
|
|
|
|
|
|
|
# The 11 behavior classes shared by all three audio classifiers; index order
# must match the checkpoints' output heads.
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']

# All models are moved to this device at load time.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_models():
    """Download and initialise every model used by the app.

    Returns:
        A 4-tuple ``(vlm_model, llm_tok, llm_model, audio_models)`` where
        ``audio_models`` maps 'A'/'B' to ViT spectrogram classifiers, 'C' to
        the fine-tuned AST classifier, and 'ast_ext' to its feature extractor.
    """
    print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")

    # Same dtype policy for both transformer models: bf16 on GPU, fp32 on CPU.
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # --- Vision-language model (video description) ---
    vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
    vlm_model = AutoModelForImageTextToText.from_pretrained(vlm_id, torch_dtype=dtype).to(DEVICE).eval()

    # --- Small instruct LLM used as the final "judge" ---
    llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    llm_tok = AutoTokenizer.from_pretrained(llm_id)
    llm_model = AutoModelForCausalLM.from_pretrained(llm_id, torch_dtype=dtype).to(DEVICE).eval()

    # --- Audio pillars A and B: ViT classifiers over mel-spectrogram images ---
    audio_models = {}
    pillar_specs = [
        ('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
        ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth'),
    ]
    for pillar, repo_id, ckpt_name in pillar_specs:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=ckpt_name)
        vit = timm.create_model("vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3)
        # NOTE(review): torch.load unpickles the downloaded checkpoint; trusted
        # here because these are the author's own repos.
        vit.load_state_dict(torch.load(ckpt_path, map_location=DEVICE)['model_state_dict'])
        audio_models[pillar] = vit.to(DEVICE).eval()

    # --- Audio pillar C: AST fine-tuned on the same 11 classes ---
    path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
    model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
    sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
    # Checkpoint keys carry an 'ast.' prefix; strip it before loading
    # (strict=False tolerates the classifier-head mismatch).
    model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
    audio_models['C'] = model_c.to(DEVICE).eval()
    audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

    return vlm_model, llm_tok, llm_model, audio_models
|
|
|
|
|
|
|
|
vlm_model, llm_tok, llm_model, audio_models = load_models() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_peace_judge(audio_ctx, vlm_desc):
    """
    Deterministic + LLM hybrid judge.
    AUDIO dominates when confidence > 30%.
    Vision can refine but never neutralize strong audio signals.

    Args:
        audio_ctx: top audio label with confidence, e.g. "PAIN (91.2%)".
        vlm_desc:  free-text visual description produced by the VLM.

    Returns:
        A verdict sentence of the form "The cat is <label>.".
    """
    vlm_lower = vlm_desc.lower()
    audio_upper = audio_ctx.upper()

    # --- Hard audio rules: urgent audio signals always win ---
    if "PAIN" in audio_upper:
        return "The cat is in pain."

    if "ANGRY" in audio_upper:
        return "The cat is angry."

    if "DEFENSIVE" in audio_upper:
        return "The cat is defensive."

    if "BACK_OFF" in audio_upper or "BACKING_OFF" in audio_upper:
        return "The cat is backing off."

    # --- Hard vision rules: unambiguous negative postures first ---
    if any(x in vlm_lower for x in [
        "front paws raised", "paws raised", "swiping",
        "hissing", "mouth open and tense"
    ]):
        return "The cat is angry."

    if any(x in vlm_lower for x in [
        "arched back", "puffed fur", "ears flat",
        "ears back", "sideways stance"
    ]):
        return "The cat is defensive."

    if any(x in vlm_lower for x in [
        "limping", "hunched", "crouched low",
        "guarding", "withdrawn posture"
    ]):
        return "The cat is in pain."

    # --- Vision rules for benign behaviors ---
    if any(x in vlm_lower for x in [
        "kneading", "rubbing", "head bumping"
    ]):
        return "The cat is affectionate."

    if any(x in vlm_lower for x in [
        "playful", "rolling", "pouncing"
    ]):
        return "The cat is happy."

    if any(x in vlm_lower for x in [
        "stalking", "tail twitching", "low crawl"
    ]):
        return "The cat is hunting."

    if any(x in vlm_lower for x in [
        "approaching human", "following human",
        "pawing at leg"
    ]):
        return "The cat is wanting attention."

    if any(x in vlm_lower for x in [
        "waiting posture", "looking at food",
        "pacing near bowl"
    ]):
        return "The cat is hungry."

    # --- LLM fallback when no deterministic rule fired ---
    messages = [
        {
            "role": "system",
            "content": (
                "You are a strict cat behavior decision engine.\n"
                "Rules:\n"
                "1. AUDIO has priority over vision.\n"
                "2. You must choose the most conservative interpretation.\n"
                "3. 'calm' is NOT a valid output.\n"
                "4. If unsure, prefer defensive or in pain.\n\n"
                "Allowed outputs ONLY:\n"
                "affectionate, angry, backing off, defensive, hungry, happy, "
                "hunting, in heat, calling kittens, in pain, wanting attention\n\n"
                "Answer format EXACTLY:\n"
                "The cat is [label]."
            )
        },
        {
            "role": "user",
            "content": (
                f"AUDIO SIGNAL (PRIMARY): {audio_ctx}\n"
                f"VISION OBSERVATIONS (SECONDARY): {vlm_desc}\n\n"
                "FINAL DECISION:"
            )
        }
    ]

    input_text = llm_tok.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        # FIX: removed temperature=0.0 — it is ignored (and triggers a
        # transformers warning) when do_sample=False; greedy decoding is
        # already deterministic.
        outputs = llm_model.generate(
            **inputs,
            max_new_tokens=15,
            do_sample=False,
            pad_token_id=llm_tok.eos_token_id,
            eos_token_id=llm_tok.eos_token_id
        )

    generated = llm_tok.decode(
        outputs[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True
    ).lower()

    # FIX: parse the labels the system prompt actually allows. The previous
    # loop searched for CATEGORIES names with underscores replaced ("feed me",
    # "mother call", "wants attention", "back off", ...) which the prompt never
    # instructs the LLM to produce, so valid answers like "hungry" or
    # "wanting attention" fell through to the default. Multi-word labels are
    # checked first to avoid partial-match collisions.
    allowed_labels = [
        "wanting attention", "calling kittens", "backing off",
        "affectionate", "defensive", "in heat", "in pain",
        "hunting", "hungry", "angry", "happy",
    ]
    for label in allowed_labels:
        if label in generated:
            return f"The cat is {label}."

    # Conservative default when the LLM output could not be parsed.
    return "The cat is defensive."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def analyze_cat_v12_final(video_path):
    """Run the full analysis pipeline on one uploaded cat video.

    Pipeline: (1) extract the audio track and score it with the three-pillar
    audio ensemble, (2) describe the video frames with SmolVLM2, (3) fuse both
    signals via call_peace_judge, (4) build a Plotly chart of the top-5 audio
    scores.

    Args:
        video_path: filesystem path of the uploaded video, or falsy when the
            user submitted nothing.

    Returns:
        Tuple (report_text, plotly_figure); the figure is None on error or
        missing input.
    """
    if not video_path:
        return "❌ Aucune vidéo.", None

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Per-request temp wav name (pid + timestamp) so concurrent calls on the
    # same machine don't clobber each other's files.
    tmp_audio = f"temp_{os.getpid()}_{int(time.time())}.wav"
    start_total = time.time()

    def clean_vlm_output(text):
        # Deduplicate sentences case-insensitively while preserving order;
        # the small VLM tends to repeat the same observation.
        sentences = text.split(". ")
        cleaned = []
        seen = set()
        for s in sentences:
            key = s.strip().lower()
            if key and key not in seen:
                seen.add(key)
                cleaned.append(s.strip())
        return ". ".join(cleaned)

    try:
        # ---------- 1) AUDIO PIPELINE ----------
        t_0 = time.time()
        clip = VideoFileClip(video_path)
        # Default: all-zero probabilities when the video has no audio track.
        audio_probs = np.zeros(len(CATEGORIES))

        if clip.audio:
            # Export first, then reload at 16 kHz keeping only the first 5 s.
            clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None)
            w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
            if len(w) < 48000:
                # Zero-pad short clips to 48000 samples (3 s at 16 kHz).
                w = np.pad(w, (0, 48000 - len(w)))

            # Mel-spectrogram rendered as a grayscale image for the ViT
            # pillars; +40/40 maps roughly [-40, 0] dB to [0, 1].
            mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
            mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
            # Pad 10 rows of zeros below the 192 mel bands, then resize to the
            # 224x224 input the ViT models expect.
            img = cv2.resize(
                (np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8),
                (224, 224)
            )

            # (1, 3, 224, 224) float tensor in [0, 1]; the grayscale image is
            # replicated across the 3 channels.
            img_t = (
                torch.tensor(img)
                .unsqueeze(0)
                .repeat(1, 3, 1, 1)
                .float()
                .to(DEVICE) / 255.0
            )

            with torch.no_grad():
                # Pillars A/B score the spectrogram image; pillar C scores the
                # raw waveform through the AST feature extractor.
                pa = F.softmax(audio_models['A'](img_t), dim=1)
                pb = F.softmax(audio_models['B'](img_t), dim=1)
                ic = audio_models['ast_ext'](
                    w, sampling_rate=16000, return_tensors="pt"
                ).to(DEVICE)
                pc = F.softmax(audio_models['C'](**ic).logits, dim=1)

            # Weighted ensemble; weights presumably tuned offline on a
            # validation set — TODO confirm.
            audio_probs = (
                pa * 0.3468 + pb * 0.2762 + pc * 0.3770
            ).cpu().numpy()[0]

        clip.close()
        t_audio = time.time() - t_0

        # ---------- 2) VISION PIPELINE (SmolVLM2) ----------
        t_1 = time.time()
        # NOTE(review): the processor is rebuilt on every request; loading it
        # once next to the model in load_models() would save per-call latency.
        vlm_proc = AutoProcessor.from_pretrained(
            "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
        )

        vlm_prompt = (
            "You are a feline behavior expert.\n"
            "Describe ONLY observable physical features:\n"
            "- ears position\n"
            "- mouth state (open/closed/tense)\n"
            "- tail position or movement\n"
            "- body posture\n"
            "Use short factual sentences.\n"
            "One observation per sentence.\n"
            "Do NOT interpret mood."
        )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "path": video_path},
                    {"type": "text", "text": vlm_prompt}
                ]
            }
        ]

        vlm_inputs = vlm_proc.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(DEVICE)

        # Remember the prompt length so only generated tokens are decoded.
        input_length = vlm_inputs["input_ids"].shape[1]

        with torch.no_grad():
            # NOTE(review): temperature=0.0 is ignored when do_sample=False
            # and may emit a transformers warning; greedy decoding is already
            # deterministic.
            vlm_out = vlm_model.generate(
                **vlm_inputs,
                max_new_tokens=80,
                do_sample=False,
                temperature=0.0,
                repetition_penalty=1.15,
                no_repeat_ngram_size=5,
                pad_token_id=vlm_proc.tokenizer.eos_token_id,
                eos_token_id=vlm_proc.tokenizer.eos_token_id
            )

        gen_tokens = vlm_out[0][input_length:]
        vlm_clean = vlm_proc.batch_decode(
            [gen_tokens], skip_special_tokens=True
        )[0]

        # Keep only the first line and strip any leading "Assistant:" prefix.
        vlm_clean = vlm_clean.strip().split("\n")[0]
        if vlm_clean.lower().startswith("assistant:"):
            vlm_clean = vlm_clean.split(":", 1)[-1].strip()

        vlm_clean = clean_vlm_output(vlm_clean)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        t_vlm = time.time() - t_1

        # ---------- 3) JUDGE: fuse audio + vision ----------
        t_2 = time.time()
        top_idx = np.argmax(audio_probs)
        audio_ctx = f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
        judge_decision = call_peace_judge(audio_ctx, vlm_clean)
        t_llm = time.time() - t_2

        # ---------- 4) CHART: top-5 audio scores ----------
        top5 = np.argsort(audio_probs)[-5:][::-1]
        fig = px.bar(
            x=[audio_probs[i] * 100 for i in top5],
            y=[CATEGORIES[i].upper() for i in top5],
            orientation="h",
            title="Top 5 Scores Audio",
            labels={"x": "Probabilité (%)", "y": "Émotion"},
            color=[audio_probs[i] * 100 for i in top5],
            color_continuous_scale="Viridis"
        )
        fig.update_layout(height=400, showlegend=False)

        # ---------- 5) REPORT ----------
        t_total = time.time() - start_total
        report = f"""⚖️ VERDICT JUGE : {judge_decision}
------------------------------------------
👁️ VISION : {vlm_clean}
📊 AUDIO : {audio_ctx}
⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Juge {t_llm:.2f}s | Total {t_total:.2f}s"""

        return report, fig

    except Exception as e:
        return f"❌ Erreur : {str(e)}", None

    finally:
        # Best-effort cleanup of the temporary wav file.
        if os.path.exists(tmp_audio):
            try:
                os.remove(tmp_audio)
            except:
                pass
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: one row with upload/trigger on the left, results on the right ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
    gr.Markdown("✅ **SmolVLM2-256M** + **SmolLM2-135M Juge** + Audio Ensemble")

    with gr.Row():
        with gr.Column():
            # Left column: video upload and the analyse button.
            video_input = gr.Video(label="Vidéo du chat")
            btn = gr.Button("🚀 ANALYSER", variant="primary", size="lg")
        with gr.Column():
            # Right column: textual report and the top-5 audio-score chart.
            report_out = gr.Textbox(label="Résultat complet", lines=12, interactive=False)
            chart_out = gr.Plot(label="Distribution des émotions (Audio)")

    btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])

demo.launch()