ericjedha committed on
Commit
bf64010
·
verified ·
1 Parent(s): 283a965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -84
app.py CHANGED
@@ -11,14 +11,14 @@ import spaces
11
  import plotly.express as px
12
  from huggingface_hub import hf_hub_download
13
  from transformers import (
14
- AutoProcessor,
15
- AutoModelForImageTextToText,
16
- ASTFeatureExtractor,
17
  ASTForAudioClassification,
18
  AutoModelForCausalLM,
19
  AutoTokenizer
20
  )
21
- from moviepy import VideoFileClip
22
 
23
  # --- Configuration ---
24
  CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
@@ -29,9 +29,9 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
  # ==========================================
30
  def load_models():
31
  print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
32
-
33
- # On charge SEULEMENT le modèle VLM (lourd), pas le processor
34
- vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
@@ -45,23 +45,23 @@ def load_models():
45
 
46
  # Audio models
47
  audio_models = {}
48
- for p, repo, f in [('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
49
- ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth')]:
50
  path = hf_hub_download(repo_id=repo, filename=f)
51
  m = timm.create_model("vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3)
52
  m.load_state_dict(torch.load(path, map_location=DEVICE)['model_state_dict'])
53
  audio_models[p] = m.to(DEVICE).eval()
54
-
55
  path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
56
  model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
57
  sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
58
  model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
59
  audio_models['C'] = model_c.to(DEVICE).eval()
60
  audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
61
-
62
  return vlm_model, llm_tok, llm_model, audio_models
63
 
64
- # Chargement global des modèles lourds (pas du processor VLM)
65
  vlm_model, llm_tok, llm_model, audio_models = load_models()
66
 
67
  # ==========================================
@@ -77,99 +77,112 @@ def call_peace_judge(audio_top, vlm_desc):
77
  inputs = llm_tok(prompt_text, return_tensors="pt").to(DEVICE)
78
  with torch.no_grad():
79
  outputs = llm_model.generate(
80
- **inputs,
81
- max_new_tokens=25,
82
  do_sample=True,
83
  temperature=0.4,
84
  top_p=0.9,
85
  pad_token_id=llm_tok.eos_token_id
86
  )
87
  res = llm_tok.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
88
- # Nettoyer les sauts de ligne, points, et garder une seule phrase
89
  res = res.strip().split('\n')[0].split('.')[0].strip()
90
  if not res.startswith("The cat"):
91
  res = "The cat " + res.lower()
92
  return res
93
 
94
  # ==========================================
95
- # 3. PIPELINE ANALYSE (Processor VLM FRESH à chaque appel)
96
  # ==========================================
97
  @spaces.GPU(duration=120)
98
  def analyze_cat_v12_final(video_path):
99
- if not video_path:
100
  return "❌ Aucune vidéo.", None
101
- if torch.cuda.is_available():
102
- torch.cuda.empty_cache()
103
 
104
- tmp_audio = f"temp_{os.getpid()}.wav"
 
 
 
105
  start_total = time.time()
106
-
107
  try:
108
  # --- A. AUDIO ---
109
  t_0 = time.time()
110
  clip = VideoFileClip(video_path)
111
  audio_probs = np.zeros(len(CATEGORIES))
 
112
  if clip.audio:
113
- clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None)
114
  w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
115
- if len(w) < 48000:
116
  w = np.pad(w, (0, 48000 - len(w)))
 
117
  mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
118
  mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
119
- img = cv2.resize((np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8), (224, 224))
 
 
 
120
  img_t = torch.tensor(img).unsqueeze(0).repeat(1, 3, 1, 1).float().to(DEVICE) / 255.0
 
121
  with torch.no_grad():
122
  pa = F.softmax(audio_models['A'](img_t), dim=1)
123
  pb = F.softmax(audio_models['B'](img_t), dim=1)
124
  ic = audio_models['ast_ext'](w, sampling_rate=16000, return_tensors="pt").to(DEVICE)
125
  pc = F.softmax(audio_models['C'](**ic).logits, dim=1)
126
  audio_probs = (pa * 0.3468 + pb * 0.2762 + pc * 0.3770).cpu().numpy()[0]
 
127
  clip.close()
128
  t_audio = time.time() - t_0
129
 
130
- # --- B. VISION (Processor FRESH à chaque appel) ---
131
- t_1 = time.time()
132
-
133
- vlm_proc = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256VLM2-Video-Instruct")
134
-
135
- vlm_prompt = (
136
- "You are a feline behavior expert. "
137
- "Analyze precisely: number and position of ears, state of mouth (open/closed/tense), tail position and movement, and overall body posture. "
138
- "Do not interpret mood. Only describe observable features."
139
- )
140
-
141
- messages = [{"role": "user", "content": [{"type": "video", "path": video_path}, {"type": "text", "text": vlm_prompt}]}]
142
-
143
- # Tokenize avec retour des inputs
144
- vlm_inputs = vlm_proc.apply_chat_template(
145
- messages,
146
- add_generation_prompt=True,
147
- tokenize=True,
148
- return_dict=True,
149
- return_tensors="pt"
150
- ).to(DEVICE)
151
-
152
- input_length = vlm_inputs["input_ids"].shape[1] # 🔑 nombre de tokens du prompt
153
-
154
- with torch.no_grad():
155
- vlm_out = vlm_model.generate(
156
- **vlm_inputs,
157
- max_new_tokens=80,
158
- do_sample=True,
159
- temperature=0.7,
160
- top_p=0.9
161
- )
162
-
163
- # 🔑 DÉCODAGE SÉCURISÉ : uniquement les nouveaux tokens
164
- gen_tokens = vlm_out[0][input_length:]
165
- vlm_clean = vlm_proc.batch_decode(gen_tokens.unsqueeze(0), skip_special_tokens=True)[0]
166
-
167
- # Nettoyage final : une seule phrase, sans "Assistant:"
168
- vlm_clean = vlm_clean.strip().split('\n')[0]
169
- if vlm_clean.lower().startswith("assistant:"):
170
- vlm_clean = vlm_clean.split(":", 1)[-1].strip()
171
-
172
- t_vlm = time.time() - t_1
 
 
 
 
 
 
173
 
174
  # --- C. JUGE ---
175
  t_2 = time.time()
@@ -181,39 +194,48 @@ t_vlm = time.time() - t_1
181
  # --- D. VISUELS ---
182
  top5 = np.argsort(audio_probs)[-5:][::-1]
183
  fig = px.bar(
184
- x=[audio_probs[i]*100 for i in top5],
185
- y=[CATEGORIES[i].upper() for i in top5],
186
- orientation='h',
187
- title='Scores Audio'
 
 
 
188
  )
 
189
 
190
  # --- E. RAPPORT ---
191
  t_total = time.time() - start_total
192
  report = f"""⚖️ VERDICT JUGE : {judge_decision}
193
- ------------------------------------------
194
- 👁️ VISION : {vlm_clean}
195
- 📊 AUDIO : {audio_ctx}
196
- ⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Total {t_total:.2f}s"""
197
 
198
- if os.path.exists(tmp_audio):
199
- os.remove(tmp_audio)
 
 
 
200
  return report, fig
201
-
202
  except Exception as e:
203
- if os.path.exists(tmp_audio):
204
- os.remove(tmp_audio)
205
  return f"❌ Erreur : {str(e)}", None
 
 
 
 
 
 
 
206
 
207
  # --- Interface Gradio ---
208
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
209
  gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
210
  with gr.Row():
211
  with gr.Column():
212
- video_input = gr.Video()
213
- btn = gr.Button("🚀 ANALYSER", variant="primary")
214
  with gr.Column():
215
- report_out = gr.Textbox(label="Résultat", lines=12)
216
- chart_out = gr.Plot()
 
217
  btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])
218
 
219
  demo.launch()
 
11
  import plotly.express as px
12
  from huggingface_hub import hf_hub_download
13
  from transformers import (
14
+ AutoProcessor,
15
+ AutoModelForImageTextToText,
16
+ ASTFeatureExtractor,
17
  ASTForAudioClassification,
18
  AutoModelForCausalLM,
19
  AutoTokenizer
20
  )
21
+ from moviepy.editor import VideoFileClip # Correction : import correct
22
 
23
  # --- Configuration ---
24
  CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
 
29
  # ==========================================
30
  def load_models():
31
  print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
32
+
33
+ # Modèle VLM (seulement le modèle, pas le processor)
34
+ vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
 
45
 
46
  # Audio models
47
  audio_models = {}
48
+ for p, repo, f in [('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
49
+ ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth')]:
50
  path = hf_hub_download(repo_id=repo, filename=f)
51
  m = timm.create_model("vit_small_patch16_224", num_classes=len(CATEGORIES), in_chans=3)
52
  m.load_state_dict(torch.load(path, map_location=DEVICE)['model_state_dict'])
53
  audio_models[p] = m.to(DEVICE).eval()
54
+
55
  path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
56
  model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
57
  sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
58
  model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
59
  audio_models['C'] = model_c.to(DEVICE).eval()
60
  audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
61
+
62
  return vlm_model, llm_tok, llm_model, audio_models
63
 
64
+ # Chargement global des modèles lourds
65
  vlm_model, llm_tok, llm_model, audio_models = load_models()
66
 
67
  # ==========================================
 
77
  inputs = llm_tok(prompt_text, return_tensors="pt").to(DEVICE)
78
  with torch.no_grad():
79
  outputs = llm_model.generate(
80
+ **inputs,
81
+ max_new_tokens=25,
82
  do_sample=True,
83
  temperature=0.4,
84
  top_p=0.9,
85
  pad_token_id=llm_tok.eos_token_id
86
  )
87
  res = llm_tok.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
 
88
  res = res.strip().split('\n')[0].split('.')[0].strip()
89
  if not res.startswith("The cat"):
90
  res = "The cat " + res.lower()
91
  return res
92
 
93
  # ==========================================
94
+ # 3. PIPELINE ANALYSE
95
  # ==========================================
96
  @spaces.GPU(duration=120)
97
  def analyze_cat_v12_final(video_path):
98
+ if not video_path:
99
  return "❌ Aucune vidéo.", None
 
 
100
 
101
+ if torch.cuda.is_available():
102
+ torch.cuda.empty_cache()
103
+
104
+ tmp_audio = f"temp_{os.getpid()}_{int(time.time())}.wav"
105
  start_total = time.time()
106
+
107
  try:
108
  # --- A. AUDIO ---
109
  t_0 = time.time()
110
  clip = VideoFileClip(video_path)
111
  audio_probs = np.zeros(len(CATEGORIES))
112
+
113
  if clip.audio:
114
+ clip.audio.write_audiofile(tmp_audio, fps=16000, logger=None, verbose=False)
115
  w, _ = librosa.load(tmp_audio, sr=16000, duration=5.0)
116
+ if len(w) < 48000:
117
  w = np.pad(w, (0, 48000 - len(w)))
118
+
119
  mel = librosa.feature.melspectrogram(y=w, sr=16000, n_mels=192)
120
  mel_db = (librosa.power_to_db(mel, ref=np.max) + 40) / 40
121
+ img = cv2.resize(
122
+ (np.vstack([mel_db, np.zeros((10, mel_db.shape[1]))]) * 255).astype(np.uint8),
123
+ (224, 224)
124
+ )
125
  img_t = torch.tensor(img).unsqueeze(0).repeat(1, 3, 1, 1).float().to(DEVICE) / 255.0
126
+
127
  with torch.no_grad():
128
  pa = F.softmax(audio_models['A'](img_t), dim=1)
129
  pb = F.softmax(audio_models['B'](img_t), dim=1)
130
  ic = audio_models['ast_ext'](w, sampling_rate=16000, return_tensors="pt").to(DEVICE)
131
  pc = F.softmax(audio_models['C'](**ic).logits, dim=1)
132
  audio_probs = (pa * 0.3468 + pb * 0.2762 + pc * 0.3770).cpu().numpy()[0]
133
+
134
  clip.close()
135
  t_audio = time.time() - t_0
136
 
137
+ # --- B. VISION (Processor chargé à chaque appel pour éviter les fuites mémoire) ---
138
+ t_1 = time.time()
139
+ vlm_proc = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
140
+
141
+ vlm_prompt = (
142
+ "You are a feline behavior expert. "
143
+ "Analyze precisely: number and position of ears, state of mouth (open/closed/tense), tail position and movement, and overall body posture. "
144
+ "Do not interpret mood. Only describe observable features."
145
+ )
146
+
147
+ messages = [
148
+ {
149
+ "role": "user",
150
+ "content": [
151
+ {"type": "video", "path": video_path},
152
+ {"type": "text", "text": vlm_prompt}
153
+ ]
154
+ }
155
+ ]
156
+
157
+ vlm_inputs = vlm_proc.apply_chat_template(
158
+ messages,
159
+ add_generation_prompt=True,
160
+ tokenize=True,
161
+ return_dict=True,
162
+ return_tensors="pt"
163
+ ).to(DEVICE)
164
+
165
+ input_length = vlm_inputs["input_ids"].shape[1]
166
+
167
+ with torch.no_grad():
168
+ vlm_out = vlm_model.generate(
169
+ **vlm_inputs,
170
+ max_new_tokens=80,
171
+ do_sample=True,
172
+ temperature=0.7,
173
+ top_p=0.9
174
+ )
175
+
176
+ gen_tokens = vlm_out[0][input_length:]
177
+ vlm_clean = vlm_proc.batch_decode([gen_tokens], skip_special_tokens=True)[0]
178
+ vlm_clean = vlm_clean.strip().split('\n')[0]
179
+ if vlm_clean.lower().startswith("assistant:"):
180
+ vlm_clean = vlm_clean.split(":", 1)[-1].strip()
181
+
182
+ if torch.cuda.is_available():
183
+ torch.cuda.empty_cache()
184
+
185
+ t_vlm = time.time() - t_1
186
 
187
  # --- C. JUGE ---
188
  t_2 = time.time()
 
194
  # --- D. VISUELS ---
195
  top5 = np.argsort(audio_probs)[-5:][::-1]
196
  fig = px.bar(
197
+ x=[audio_probs[i]*100 for i in top5],
198
+ y=[CATEGORIES[i].upper() for i in top5],
199
+ orientation='h',
200
+ title='Top 5 Scores Audio',
201
+ labels={'x': 'Probabilité (%)', 'y': 'Émotion'},
202
+ color=[audio_probs[i]*100 for i in top5],
203
+ color_continuous_scale='Viridis'
204
  )
205
+ fig.update_layout(height=400, showlegend=False)
206
 
207
  # --- E. RAPPORT ---
208
  t_total = time.time() - start_total
209
  report = f"""⚖️ VERDICT JUGE : {judge_decision}
 
 
 
 
210
 
211
+ ------------------------------------------
212
+ 👁️ VISION : {vlm_clean}
213
+ 📊 AUDIO : {audio_ctx}
214
+ ⏱️ TEMPS : Audio {t_audio:.2f}s | Vision {t_vlm:.2f}s | Juge {t_llm:.2f}s | Total {t_total:.2f}s"""
215
+
216
  return report, fig
217
+
218
  except Exception as e:
 
 
219
  return f"❌ Erreur : {str(e)}", None
220
+
221
+ finally:
222
+ if os.path.exists(tmp_audio):
223
+ try:
224
+ os.remove(tmp_audio)
225
+ except:
226
+ pass
227
 
228
  # --- Interface Gradio ---
229
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
230
  gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
231
  with gr.Row():
232
  with gr.Column():
233
+ video_input = gr.Video(label="Vidéo du chat")
234
+ btn = gr.Button("🚀 ANALYSER", variant="primary", size="lg")
235
  with gr.Column():
236
+ report_out = gr.Textbox(label="Résultat complet", lines=12, interactive=False)
237
+ chart_out = gr.Plot(label="Distribution des émotions (Audio)")
238
+
239
  btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])
240
 
241
  demo.launch()