ericjedha committed on
Commit
0aa2638
·
verified ·
1 Parent(s): 6a45583

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -224
app.py CHANGED
@@ -1,237 +1,98 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
  import gradio as gr
5
- import librosa
6
- import numpy as np
7
- import cv2
8
- import timm
9
- import os
10
- import warnings
11
- import logging
12
- import time
13
- from pathlib import Path
14
- from huggingface_hub import hf_hub_download
15
  from transformers import (
16
- AutoProcessor,
17
- AutoModelForImageTextToText,
18
- ASTFeatureExtractor,
19
- ASTForAudioClassification
20
  )
21
- from PIL import Image
22
- from moviepy import VideoFileClip
23
 
24
# --- Configuration & Silence ---
# Mute asyncio chatter and library warnings so the Gradio console stays readable.
logging.getLogger("asyncio").setLevel(logging.CRITICAL)
warnings.filterwarnings("ignore")

# Behaviour classes predicted by the audio ensemble.
# Order matters: the index of each name is the class id used everywhere below.
CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
TARGET_SR = 16000  # sample rate (Hz) expected by all audio models
MAX_SEC = 5.0      # analyse at most the first 5 seconds of the clip
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
32
 
33
# ==========================================
# FUSION LOGIC V6.4 (behavioural priority)
# ==========================================
def apply_visual_logic_v6(description, audio_probs):
    """Re-weight audio class probabilities with visual cues found in *description*.

    Returns the re-normalised score vector and the list of rule labels applied.
    """
    weighted = audio_probs.copy()
    cues = description.lower()
    rules_applied = []

    # Detected facial states.
    teeth_seen = "teeth" in cues
    ears_back_seen = "ears back" in cues
    ears_forward_seen = "ears forward" in cues

    def _boost(category_names, factor):
        # Multiply the score of each named category in place.
        for name in category_names:
            weighted[CATEGORIES.index(name)] *= factor

    # 1. GOLDEN RULE — aggression has priority: if teeth or flattened ears were
    #    seen at any moment, tension categories win and "ears forward" is ignored.
    if ears_back_seen or teeth_seen:
        _boost(["angry", "back_off", "defensive"], 4.0)
        rules_applied.append("⚠️ PRIORITÉ AGRESSION (Dents/Oreilles arrière détectées)")

    # 2. Forward ears only count when no aggression cue contradicts them.
    if ears_forward_seen and not (ears_back_seen or teeth_seen):
        _boost(["happy", "hunt", "wants_attention"], 2.0)
        rules_applied.append("✅ Calme/Alerte (Ears forward)")

    # 3. Wide eyes indicate internal tension.
    if "eyes wide" in cues:
        _boost(["angry", "back_off", "pain"], 2.0)
        rules_applied.append("✅ Yeux écarquillés")

    # Re-normalise so the scores sum to 1 (skip if everything is zero).
    total = np.sum(weighted)
    if total > 0:
        weighted /= total
    return weighted, rules_applied
70
-
71
# ==========================================
# NARRATIVE PARSING (keyword detection)
# ==========================================
def parse_narrative_to_indices(text):
    """Scan a free-text VLM description for known visual cues.

    Returns the recognised state labels (e.g. "ears back", "teeth")
    joined by spaces, in set order.
    """
    # Pad with spaces so leading-space cues like " up" match at word starts.
    haystack = f" {text.lower()} "

    # Each label is implied by any of its substring cues.
    cue_table = [
        # EARS: precise position terms
        ("ears back", [" back ", "backward", " flat", " down", " low"]),
        ("ears forward", [" forward", "upright", "pointed", " up"]),
        # MOUTH: avoid confusing 'open' with other forms
        ("mouth open", [" open", "hiss", "snarl", "meow", "yawn"]),
        ("teeth", ["teeth", "fangs", "sharp"]),
        # EYES & FOREHEAD
        ("eyes wide", ["wide", "dilated", "staring"]),
        ("eyes squinted", ["squint", "closed eyes", "blink"]),
        ("forehead wrinkled", ["wrinkl", "furrow", "tense forehead"]),
    ]

    found = {label for label, cues in cue_table
             if any(cue in haystack for cue in cues)}
    return " ".join(found)
99
-
100
# ==========================================
# AUDIO MODEL LOADING
# ==========================================
def load_audio_models():
    """Download and load the three audio classifiers of the ensemble.

    Returns a dict with keys 'A', 'B' (ViT-Small over mel-spectrogram images),
    'C' (AST fine-tuned for these categories) and 'ast_ext' (the AST feature
    extractor). All models are moved to DEVICE and set to eval mode.
    """
    models = {}
    # Pillars A and B: ViT-Small classifiers fed with 3-channel spectrogram "images".
    for p, repo, f in [('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
                       ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth')]:
        path = hf_hub_download(repo_id=repo, filename=f)
        m = timm.create_model("vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3)
        m.load_state_dict(torch.load(path, map_location=DEVICE)['model_state_dict'])
        models[p] = m.to(DEVICE).eval()

    # Pillar C: Audio Spectrogram Transformer with a replaced classification head.
    path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
    model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
    sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
    # Checkpoint keys carry an 'ast.' prefix; strip it so they line up with the
    # HF module names. strict=False tolerates the mismatched head weights.
    model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
    models['C'] = model_c.to(DEVICE).eval()
    models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    return models
119
-
120
# ==========================================
# INITIALISATION
# ==========================================
print("📥 Initialisation CatSense v8.4 (Full Face Analysis)...")
vlm_id = "HuggingFaceTB/SmolVLM2-256M-Instruct"
vlm_proc = AutoProcessor.from_pretrained(vlm_id)
vlm_model = AutoModelForImageTextToText.from_pretrained(
    vlm_id,
    # bfloat16 halves memory on GPU; CPU falls back to float32.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    _attn_implementation="sdpa"  # PyTorch scaled-dot-product attention backend
).to(DEVICE).eval()

# Download and load the three-model audio ensemble once at startup.
audio_ensemble = load_audio_models()
133
-
134
# ==========================================
# AUDIO INFERENCE
# ==========================================
def get_audio_probs(path):
    """Run the 3-model audio ensemble on a WAV file.

    Returns the weighted-average class probability vector
    (shape: (len(CATEGORIES),), numpy array on CPU).
    """
    w, _ = librosa.load(path, sr=TARGET_SR, duration=MAX_SEC)
    # Pad to 3 s (48000 samples at 16 kHz).
    # NOTE(review): MAX_SEC is 5 s, so clips between 3 and 5 s keep a variable
    # length here — presumably this matches how the models were trained; confirm.
    if len(w) < 48000: w = np.pad(w, (0, 48000-len(w)))
    mel = librosa.feature.melspectrogram(y=w, sr=TARGET_SR, n_mels=192)
    # power_to_db with ref=max yields values <= 0 dB; shift/scale toward [0, 1].
    mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
    # Stack 10 zero rows under the 192 mel bands, scale to uint8, resize to ViT input size.
    img = cv2.resize((np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8), (224, 224))
    # (224, 224) uint8 -> (1, 3, 224, 224) float tensor in [0, 1] (grayscale replicated to 3 channels).
    img_t = torch.tensor(img).unsqueeze(0).repeat(1, 3, 1, 1).float().to(DEVICE) / 255.0
    with torch.no_grad():
        pa = F.softmax(audio_ensemble['A'](img_t), dim=1)
        pb = F.softmax(audio_ensemble['B'](img_t), dim=1)
        # Pillar C consumes the raw waveform through the AST feature extractor.
        ic = audio_ensemble['ast_ext'](w, sampling_rate=TARGET_SR, return_tensors="pt").to(DEVICE)
        pc = F.softmax(audio_ensemble['C'](**ic).logits, dim=1)
    # Ensemble weights — presumably tuned on a validation split; confirm.
    return (pa * 0.3468 + pb * 0.2762 + pc * 0.3770).cpu().numpy()[0]
150
-
151
# ==========================================
# MAIN ANALYSIS
# ==========================================
def analyze_cat(video_path):
    """Full pipeline: frames + audio extraction, audio ensemble, VLM description,
    rule-based fusion, and report formatting.

    Returns (frames_pil, report_text); on any failure returns ([], error_text)
    with the full traceback embedded for debugging in the UI.
    """
    if video_path is None: return [], "❌ Vidéo absente."
    start_time = time.time()
    # Per-process temp name avoids collisions between concurrent workers.
    tmp_audio = f"temp_{os.getpid()}.wav"

    try:
        # 1. Frame & audio extraction
        with VideoFileClip(video_path) as clip:
            # NOTE(review): clip.audio is None for silent videos; that raises
            # AttributeError here and is surfaced by the except block below.
            clip.audio.write_audiofile(tmp_audio, fps=TARGET_SR, logger=None)
            duration = min(clip.duration, MAX_SEC)
            # Sample 3 frames at 10%, 50% and 90% of the analysed window.
            ts = [0.1 * duration, 0.5 * duration, 0.9 * duration]
            frames_pil = []
            for t in ts:
                img = Image.fromarray(clip.get_frame(t)).convert("RGB")
                w, h = img.size
                # Focused crop: drop the outer borders where the subject rarely is.
                img = img.crop((int(w*0.12), int(h*0.05), int(w*0.88), int(h*0.85)))
                img.thumbnail((384, 384))
                frames_pil.append(img)

        # 2. Audio ensemble probabilities
        raw_audio_probs = get_audio_probs(tmp_audio)

        # 3. Vision: complete expert prompt for the VLM
        prompt = """Analyze the cat in these 3 frames. List the state of:
1. Mouth and Teeth (Visible?)
2. Ears (Position?)
3. Eyes (Wide or squinted?)
4. Forehead (Wrinkled or smooth?)
5. Tail (Is it up or down?)
Be very direct for each image."""

        # One user turn carrying the 3 image slots followed by the text prompt.
        messages = [{"role": "user", "content": [{"type": "image"}]*3 + [{"type": "text", "text": prompt}]}]
        text_prompt = vlm_proc.apply_chat_template(messages, add_generation_prompt=True)
        inputs = vlm_proc(text=text_prompt, images=frames_pil, return_tensors="pt").to(DEVICE)

        with torch.no_grad():
            outputs = vlm_model.generate(**inputs, max_new_tokens=120, do_sample=False, temperature=0.0)

        # Keep only the assistant part of the decoded conversation.
        vlm_desc = vlm_proc.decode(outputs[0], skip_special_tokens=True).split("assistant")[-1].strip()

        # 4. Fusion of visual cues with audio probabilities
        visual_mapped = parse_narrative_to_indices(vlm_desc)
        final_probs, indices = apply_visual_logic_v6(visual_mapped, raw_audio_probs)
        final_idx = np.argmax(final_probs)

        # Clean up the temp audio file before building the report.
        if os.path.exists(tmp_audio): os.remove(tmp_audio)
        elapsed = time.time() - start_time

        # 5. Report
        res = f"🏆 VERDICT : {CATEGORIES[final_idx].upper()}\n"
        res += f"🎯 CONFIANCE : {final_probs[final_idx]:.1%}\n"
        res += f"⏱️ VITESSE : {elapsed:.2f}s\n"
        res += f"------------------------------------------\n"
        res += f"👁️ ANALYSE VISUELLE :\n{vlm_desc}\n"
        res += f"------------------------------------------\n"
        res += f"🔎 SIGNES RETENUS : {', '.join(indices) if indices else 'Aucun'}\n"
        res += f"🔊 AUDIO DOMINANT : {CATEGORIES[np.argmax(raw_audio_probs)].upper()}\n"
        res += f"📊 TOP 3 :\n"
        for i in np.argsort(final_probs)[::-1][:3]:
            res += f" - {CATEGORIES[i]}: {final_probs[i]:.1%}\n"

        return frames_pil, res

    except Exception as e:
        # Best-effort cleanup, then surface the full traceback in the UI.
        if os.path.exists(tmp_audio): os.remove(tmp_audio)
        import traceback
        return [], f"❌ Erreur critique :\n{traceback.format_exc()}"
222
-
223
# ==========================================
# INTERFACE
# ==========================================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐱 CatSense v8.4 - Expert Behavior Analysis")

    with gr.Row():
        # Left column: video input and launch button.
        with gr.Column():
            v_in = gr.Video(label="Vidéo (5s)")
            btn = gr.Button("🚀 LANCER L'EXPERTISE", variant="primary")
        # Right column: extracted frames and the fusion report.
        with gr.Column():
            gal = gr.Gallery(label="Indices Visuels", columns=3)
            out = gr.Textbox(label="Rapport de Fusion", lines=20)
    btn.click(fn=analyze_cat, inputs=v_in, outputs=[gal, out])

demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from threading import Thread
 
 
 
 
 
 
 
 
4
  from transformers import (
5
+ SmolVLMProcessor,
6
+ AutoModelForImageTextToText,
7
+ TextIteratorStreamer,
 
8
  )
 
 
9
 
10
# ======================
# MODEL INITIALISATION
# ======================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_ID = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

# Processor handles both image preprocessing and chat templating.
processor = SmolVLMProcessor.from_pretrained(MODEL_ID)

# bfloat16 keeps GPU memory low; CPU falls back to float32.
_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, torch_dtype=_dtype)
model = model.to(DEVICE).eval()
21
 
22
+
23
# ======================
# STREAMING INFERENCE
# ======================
def analyze_stream(text, image, max_tokens):
    """Stream a SmolVLM2 answer for an optional image plus text question.

    Yields the growing partial answer so Gradio can render it live.

    Args:
        text: user question/description (may be empty when an image is given).
        image: filepath of the uploaded image, or None.
        max_tokens: generation budget passed as max_new_tokens.
    """
    # This function is a generator: a bare `return "msg"` would become the
    # StopIteration value and never be shown by Gradio — errors must be yielded.
    if image is None and not text.strip():
        yield "❌ Veuillez fournir un texte ou une image."
        return

    # Build the single multimodal user turn.
    content = []
    if image:
        content.append({"type": "image", "path": image})
    if text.strip():
        content.append({"type": "text", "text": text})

    messages = [{"role": "user", "content": content}]

    # return_dict=True is required to get a mapping of tensors that can be
    # splatted into generate(); without it apply_chat_template returns only
    # token ids and `**inputs` below would fail.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(DEVICE)

    streamer = TextIteratorStreamer(
        # NOTE(review): the streamer calls .decode(); assumes the processor
        # forwards it to its tokenizer — confirm, else pass processor.tokenizer.
        processor,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Run generation in a background thread; the streamer feeds tokens back.
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=False,
            temperature=0.0,
        ),
        daemon=True,  # don't block interpreter exit if the client disconnects
    ).start()

    # Accumulate and re-yield the growing answer for live display.
    output = ""
    for token in streamer:
        output += token
        yield output
66
+
67
+
68
# ======================
# GRADIO UI
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## SmolVLM2 Analyse Temps Réel")

    with gr.Row():
        # Left column: user inputs and the launch button.
        with gr.Column():
            question_box = gr.Textbox(label="Question / Description", lines=3)
            image_input = gr.Image(type="filepath", label="Image")
            token_budget = gr.Slider(50, 400, value=200, step=50, label="Max Tokens")
            run_button = gr.Button("🚀 Analyser", variant="primary")

        # Right column: the streamed answer.
        with gr.Column():
            answer_box = gr.Textbox(label="Réponse en Temps Réel", lines=14)

    # analyze_stream is a generator, so the textbox updates as tokens arrive.
    run_button.click(
        fn=analyze_stream,
        inputs=[question_box, image_input, token_budget],
        outputs=answer_box,
    )

demo.launch()