ericjedha commited on
Commit
bd04204
·
verified ·
1 Parent(s): e6bbe3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -87
app.py CHANGED
@@ -1,100 +1,237 @@
1
- import gradio as gr
2
  import torch
3
- from threading import Thread
 
 
 
 
 
 
 
 
 
 
 
 
4
  from transformers import (
5
- AutoProcessor,
6
- AutoModelForVision2Seq,
7
- TextIteratorStreamer,
 
8
  )
 
 
 
 
 
 
9
 
10
- # ======================
11
- # INIT MODÈLE
12
- # ======================
13
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
- MODEL_ID = "HuggingFaceTB/SmolVLM2-256M-Instruct" # version compatible HF
15
 
16
- processor = AutoProcessor.from_pretrained(MODEL_ID)
17
- model = AutoModelForVision2Seq.from_pretrained(
18
- MODEL_ID,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
 
20
  ).to(DEVICE).eval()
21
 
22
- # ======================
23
- # STREAMING INFERENCE
24
- # ======================
25
- def analyze_stream(text, image, max_tokens):
26
- if not text.strip() and image is None:
27
- return "❌ Veuillez fournir un texte ou une image."
28
-
29
- # Construire le contenu
30
- content = []
31
- if image:
32
- content.append({"type": "image", "path": image})
33
- if text.strip():
34
- content.append({"type": "text", "text": text})
35
-
36
- messages = [{"role": "user", "content": content}]
37
-
38
- # Préparer les inputs
39
- inputs = processor.apply_chat_template(
40
- messages,
41
- add_generation_prompt=True,
42
- tokenize=True,
43
- return_tensors="pt",
44
- ).to(DEVICE)
45
-
46
- # Créer le streamer
47
- streamer = TextIteratorStreamer(
48
- processor,
49
- skip_prompt=True,
50
- skip_special_tokens=True,
51
- )
52
-
53
- # Lancer la génération dans un thread
54
- Thread(
55
- target=model.generate,
56
- kwargs=dict(
57
- **inputs,
58
- streamer=streamer,
59
- max_new_tokens=max_tokens,
60
- do_sample=False,
61
- temperature=0.0,
62
- ),
63
- ).start()
64
-
65
- # Yield token par token
66
- output = ""
67
- for token in streamer:
68
- output += token
69
- yield output
70
-
71
- # ======================
72
- # INTERFACE GRADIO
73
- # ======================
74
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
75
- gr.Markdown("## ⚡ SmolVLM2 – Analyse en temps réel (Streaming)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  with gr.Row():
78
  with gr.Column():
79
- txt = gr.Textbox(
80
- label="Question / Description",
81
- placeholder="Posez une question ou décrivez l'image",
82
- lines=3,
83
- )
84
- img = gr.Image(type="filepath", label="Image")
85
- max_tokens = gr.Slider(50, 400, value=200, step=50, label="Max Tokens")
86
- btn = gr.Button("🚀 Analyser", variant="primary")
87
-
88
  with gr.Column():
89
- out = gr.Textbox(
90
- label="Réponse en temps réel",
91
- lines=14,
92
- )
93
-
94
- btn.click(
95
- fn=analyze_stream,
96
- inputs=[txt, img, max_tokens],
97
- outputs=out,
98
- )
99
-
100
- demo.launch()
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ import librosa
6
+ import numpy as np
7
+ import cv2
8
+ import timm
9
+ import os
10
+ import warnings
11
+ import logging
12
+ import time
13
+ from pathlib import Path
14
+ from huggingface_hub import hf_hub_download
15
  from transformers import (
16
+ AutoProcessor,
17
+ AutoModelForImageTextToText,
18
+ ASTFeatureExtractor,
19
+ ASTForAudioClassification
20
  )
21
+ from PIL import Image
22
+ from moviepy import VideoFileClip
23
+
24
+ # --- Configuration & Silence ---
25
+ logging.getLogger("asyncio").setLevel(logging.CRITICAL)
26
+ warnings.filterwarnings("ignore")
27
 
28
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
TARGET_SR = 16000
MAX_SEC = 5.0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==========================================
# FUSION LOGIC V6.4 (behavioral priority)
# ==========================================
def apply_visual_logic_v6(description, audio_probs):
    """Re-weight the audio class probabilities with visual cues found in *description*.

    Returns a tuple ``(scores, applied_rules)`` where ``scores`` is a copy of
    ``audio_probs`` boosted per rule and re-normalized to sum to 1 (when the
    sum is positive), and ``applied_rules`` lists a label for each rule fired.
    """
    desc = description.lower()
    scores = audio_probs.copy()
    applied_rules = []

    # Detected facial states (simple substring flags).
    teeth_seen = "teeth" in desc
    ears_back_seen = "ears back" in desc
    ears_forward_seen = "ears forward" in desc

    def _boost(category_names, factor):
        # Multiply the score of each named category in place.
        for name in category_names:
            scores[CATEGORIES.index(name)] *= factor

    # 1. Golden rule: aggression cues take priority. If teeth or flattened
    #    ears were seen at any moment, tension categories are boosted hard.
    if ears_back_seen or teeth_seen:
        _boost(("angry", "back_off", "defensive"), 4.0)
        applied_rules.append("⚠️ PRIORITÉ AGRESSION (Dents/Oreilles arrière détectées)")

    # 2. Calm/alert cue counts only when it does not contradict aggression.
    if ears_forward_seen and not (ears_back_seen or teeth_seen):
        _boost(("happy", "hunt", "wants_attention"), 2.0)
        applied_rules.append("✅ Calme/Alerte (Ears forward)")

    # 3. Wide eyes: internal-tension indicator.
    if "eyes wide" in desc:
        _boost(("angry", "back_off", "pain"), 2.0)
        applied_rules.append("✅ Yeux écarquillés")

    total = np.sum(scores)
    if total > 0:
        scores /= total
    return scores, applied_rules
70
+
71
+ # ==========================================
72
+ # PARSING NARRATIF (Détection de mots-clés)
73
+ # ==========================================
74
def parse_narrative_to_indices(text):
    """Map a free-text VLM description to canonical facial-cue tokens.

    Scans *text* for keyword fragments and returns the detected cue names
    joined by single spaces. The result is sorted so the output is
    deterministic — the previous ``" ".join(list(found))`` joined a set,
    whose iteration order varies across interpreter runs (string hash
    randomization). Downstream fusion only does substring checks, so the
    ordering change is behavior-safe.
    """
    found = set()
    # Pad with spaces so edge words can match the space-anchored patterns.
    text = f" {text.lower()} "

    # EARS: space-prefixed fragments reduce accidental mid-word matches.
    if any(x in text for x in (" back ", "backward", " flat", " down", " low")):
        found.add("ears back")
    if any(x in text for x in (" forward", "upright", "pointed", " up")):
        found.add("ears forward")

    # MOUTH. NOTE(review): " open" also matches "opening" — the original
    # comment claimed otherwise; kept as-is to preserve behavior.
    if any(x in text for x in (" open", "hiss", "snarl", "meow", "yawn")):
        found.add("mouth open")
    if any(x in text for x in ("teeth", "fangs", "sharp")):
        found.add("teeth")

    # EYES & FOREHEAD ("wrinkl" covers wrinkle/wrinkled/wrinkling).
    if any(x in text for x in ("wide", "dilated", "staring")):
        found.add("eyes wide")
    if any(x in text for x in ("squint", "closed eyes", "blink")):
        found.add("eyes squinted")
    if any(x in text for x in ("wrinkl", "furrow", "tense forehead")):
        found.add("forehead wrinkled")

    return " ".join(sorted(found))
99
+
100
+ # ==========================================
101
+ # CHARGEMENT DES MODÈLES AUDIO
102
+ # ==========================================
103
def load_audio_models():
    """Download and initialize the three-pillar audio classifier ensemble.

    Returns a dict with keys 'A', 'B' (ViT-small spectrogram classifiers),
    'C' (fine-tuned AST model) and 'ast_ext' (the AST feature extractor).
    """
    models = {}

    # Pillars A and B: ViT-small models over mel-spectrogram images.
    vit_checkpoints = [
        ('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
        ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth'),
    ]
    for key, repo_id, filename in vit_checkpoints:
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        net = timm.create_model("vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3)
        net.load_state_dict(torch.load(ckpt_path, map_location=DEVICE)['model_state_dict'])
        models[key] = net.to(DEVICE).eval()

    # Pillar C: Audio Spectrogram Transformer with a retrained class head.
    path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
    model_c = ASTForAudioClassification.from_pretrained(
        "MIT/ast-finetuned-audioset-10-10-0.4593",
        num_labels=len(CATEGORIES),
        ignore_mismatched_sizes=True,
    )
    state = torch.load(path_c, map_location=DEVICE)['model_state_dict']
    # Checkpoint keys carry an 'ast.' prefix the bare model does not expect.
    model_c.load_state_dict({k.replace('ast.', ''): v for k, v in state.items()}, strict=False)
    models['C'] = model_c.to(DEVICE).eval()
    models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    return models
119
+
120
+ # ==========================================
121
+ # INITIALISATION
122
+ # ==========================================
123
# Module-level initialization: the VLM is loaded once at import time.
print("📥 Initialisation CatSense v8.4 (Full Face Analysis)...")
vlm_id = "HuggingFaceTB/SmolVLM2-256M-Instruct"
vlm_proc = AutoProcessor.from_pretrained(vlm_id)
vlm_model = AutoModelForImageTextToText.from_pretrained(
    vlm_id,
    # bfloat16 only when a GPU is available; CPU stays in float32.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    _attn_implementation="sdpa"  # scaled-dot-product attention backend
).to(DEVICE).eval()

# Audio ensemble (pillars A/B/C) — checkpoints fetched from the Hub at startup.
audio_ensemble = load_audio_models()
133
+
134
+ # ==========================================
135
+ # INFERENCE AUDIO
136
+ # ==========================================
137
def get_audio_probs(path):
    """Run the 3-model audio ensemble on the WAV file at *path*.

    Returns a 1-D numpy array of class probabilities over CATEGORIES,
    a fixed weighted blend of pillars A, B and C.
    """
    # Load at most MAX_SEC seconds of audio, resampled to 16 kHz mono.
    w, _ = librosa.load(path, sr=TARGET_SR, duration=MAX_SEC)
    # Zero-pad to at least 48000 samples (3 s at 16 kHz).
    # NOTE(review): pads to 3 s while MAX_SEC allows 5 s — confirm this matches
    # the training setup of pillars A/B.
    if len(w) < 48000: w = np.pad(w, (0, 48000-len(w)))
    # Mel spectrogram (192 bands), shifted/scaled from dB into roughly [0, 1].
    mel = librosa.feature.melspectrogram(y=w, sr=TARGET_SR, n_mels=192)
    mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
    # Pad with 10 zero rows, scale to uint8 and resize to the ViT input size.
    img = cv2.resize((np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8), (224, 224))
    # (1, 3, 224, 224) float tensor in [0, 1]: grayscale replicated to 3 channels.
    img_t = torch.tensor(img).unsqueeze(0).repeat(1, 3, 1, 1).float().to(DEVICE) / 255.0
    with torch.no_grad():
        pa = F.softmax(audio_ensemble['A'](img_t), dim=1)
        pb = F.softmax(audio_ensemble['B'](img_t), dim=1)
        # Pillar C consumes the raw waveform via the AST feature extractor.
        ic = audio_ensemble['ast_ext'](w, sampling_rate=TARGET_SR, return_tensors="pt").to(DEVICE)
        pc = F.softmax(audio_ensemble['C'](**ic).logits, dim=1)
    # Ensemble weights — presumably tuned on a validation set; verify source.
    return (pa * 0.3468 + pb * 0.2762 + pc * 0.3770).cpu().numpy()[0]
150
+
151
+ # ==========================================
152
+ # ANALYSE PRINCIPALE
153
+ # ==========================================
154
def analyze_cat(video_path):
    """Full pipeline: extract frames and audio from the video, run the audio
    ensemble and the VLM, fuse both signals and build the textual report.

    Returns ``(frames, report)``: a list of PIL frames for the gallery and a
    formatted report string. On any failure the frame list is empty and the
    report carries the error.
    """
    if video_path is None: return [], "❌ Vidéo absente."
    start_time = time.time()
    tmp_audio = f"temp_{os.getpid()}.wav"

    try:
        # 1. Extract frames & audio track.
        with VideoFileClip(video_path) as clip:
            # Guard: a silent video has clip.audio == None and previously
            # crashed with an opaque AttributeError.
            if clip.audio is None:
                return [], "❌ La vidéo ne contient pas de piste audio."
            clip.audio.write_audiofile(tmp_audio, fps=TARGET_SR, logger=None)
            duration = min(clip.duration, MAX_SEC)
            # Three sample points: near the start, middle and end.
            ts = [0.1 * duration, 0.5 * duration, 0.9 * duration]
            frames_pil = []
            for t in ts:
                img = Image.fromarray(clip.get_frame(t)).convert("RGB")
                w, h = img.size
                # Focused crop: trim outer margins to center on the subject.
                img = img.crop((int(w*0.12), int(h*0.05), int(w*0.88), int(h*0.85)))
                img.thumbnail((384, 384))
                frames_pil.append(img)

        # 2. Audio ensemble probabilities.
        raw_audio_probs = get_audio_probs(tmp_audio)

        # 3. Vision: expert prompt over the 3 frames.
        prompt = """Analyze the cat in these 3 frames. List the state of:
1. Mouth and Teeth (Visible?)
2. Ears (Position?)
3. Eyes (Wide or squinted?)
4. Forehead (Wrinkled or smooth?)
5. Tail (Is it up or down?)
Be very direct for each image."""

        messages = [{"role": "user", "content": [{"type": "image"}]*3 + [{"type": "text", "text": prompt}]}]
        text_prompt = vlm_proc.apply_chat_template(messages, add_generation_prompt=True)
        inputs = vlm_proc(text=text_prompt, images=frames_pil, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            # Greedy decoding; `temperature` dropped — it is ignored (and
            # warned about) when do_sample=False.
            outputs = vlm_model.generate(**inputs, max_new_tokens=120, do_sample=False)

        # Keep only the assistant's turn from the decoded transcript.
        vlm_desc = vlm_proc.decode(outputs[0], skip_special_tokens=True).split("assistant")[-1].strip()

        # 4. Fuse visual cues with audio probabilities.
        visual_mapped = parse_narrative_to_indices(vlm_desc)
        final_probs, indices = apply_visual_logic_v6(visual_mapped, raw_audio_probs)
        final_idx = np.argmax(final_probs)

        elapsed = time.time() - start_time

        # 5. Report.
        res = f"🏆 VERDICT : {CATEGORIES[final_idx].upper()}\n"
        res += f"🎯 CONFIANCE : {final_probs[final_idx]:.1%}\n"
        res += f"⏱️ VITESSE : {elapsed:.2f}s\n"
        res += f"------------------------------------------\n"
        res += f"👁️ ANALYSE VISUELLE :\n{vlm_desc}\n"
        res += f"------------------------------------------\n"
        res += f"🔎 SIGNES RETENUS : {', '.join(indices) if indices else 'Aucun'}\n"
        res += f"🔊 AUDIO DOMINANT : {CATEGORIES[np.argmax(raw_audio_probs)].upper()}\n"
        res += f"📊 TOP 3 :\n"
        for i in np.argsort(final_probs)[::-1][:3]:
            res += f" - {CATEGORIES[i]}: {final_probs[i]:.1%}\n"

        return frames_pil, res

    except Exception:
        import traceback
        return [], f"❌ Erreur critique :\n{traceback.format_exc()}"
    finally:
        # Cleanup was previously duplicated in the try and except paths;
        # `finally` guarantees the temp WAV is removed on every exit.
        if os.path.exists(tmp_audio):
            os.remove(tmp_audio)
222
+
223
+ # ==========================================
224
+ # INTERFACE
225
+ # ==========================================
226
# Gradio UI: one video input on the left, gallery + report on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v8.4 - Expert Behavior Analysis")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            video_input = gr.Video(label="Vidéo (5s)")
            launch_button = gr.Button("🚀 LANCER L'EXPERTISE", variant="primary")
        # Right column: outputs.
        with gr.Column():
            frames_gallery = gr.Gallery(label="Indices Visuels", columns=3)
            report_box = gr.Textbox(label="Rapport de Fusion", lines=20)
    launch_button.click(fn=analyze_cat, inputs=video_input, outputs=[frames_gallery, report_box])

demo.launch()