ericjedha committed on
Commit
455bb26
·
verified ·
1 Parent(s): a7d8fd5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -49
app.py CHANGED
@@ -18,25 +18,25 @@ from transformers import (
18
  AutoModelForCausalLM,
19
  AutoTokenizer
20
  )
21
- from moviepy import VideoFileClip # ← Changement principal : import direct depuis moviepy (v2.x)
22
 
23
  # --- Configuration ---
24
  CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ==========================================
28
- # 1. CHARGEMENT DES MODÈLES (sans le VLM processor)
29
  # ==========================================
30
  def load_models():
31
  print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
32
 
33
- # Modèle VLM (seulement le modèle, pas le processor)
34
  vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
38
 
39
- # LLM
40
  llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
41
  llm_tok = AutoTokenizer.from_pretrained(llm_id)
42
  llm_model = AutoModelForCausalLM.from_pretrained(
@@ -61,70 +61,50 @@ def load_models():
61
 
62
  return vlm_model, llm_tok, llm_model, audio_models
63
 
64
- # Chargement global des modèles lourds
65
  vlm_model, llm_tok, llm_model, audio_models = load_models()
66
 
67
  # ==========================================
68
- # 2. LOGIQUE DU JUGE (avec stochasticité)
69
  # ==========================================
70
  def call_peace_judge(audio_top, vlm_desc):
71
- # Utilise le format messages + chat template (recommandé par HF pour SmolLM2-Instruct)
72
  messages = [
73
  {
74
  "role": "system",
75
- "content": """You are a cat behavior expert. Use audio result and video description to describe the behavior of the cat: Answer with ONLY one short sentence starting exactly with "The cat is" : use only these words: affectionate, angry, backing off, defensive, hungry, happy, hunting, in heat, calling kittens, in pain, wanting attention, calm.
76
- No explanation. No extra words."""
 
77
  },
78
  {
79
  "role": "user",
80
- "content": f"Audio analysis (most reliable for vocalizations): {audio_top}\nVisual description (posture and body language): {vlm_desc}"
81
  }
82
  ]
83
 
84
- # Applique le chat template correctement
85
- input_text = llm_tok.apply_chat_template(
86
- messages,
87
- tokenize=False,
88
- add_generation_prompt=True # Ajoute le token pour commencer la réponse assistant
89
- )
90
-
91
  inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)
92
 
93
  with torch.no_grad():
94
  outputs = llm_model.generate(
95
- **inputs,
96
- max_new_tokens=30, # Un peu plus pour être sûr
97
- do_sample=True,
98
- temperature=0.3, # Augmente un peu pour plus de créativité
99
- top_p=0.90,
100
  pad_token_id=llm_tok.eos_token_id,
101
  eos_token_id=llm_tok.eos_token_id
102
  )
103
 
104
- # Décode seulement les nouveaux tokens
105
  generated = llm_tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
106
- generated = generated.strip()
107
-
108
- # Nettoyage final
109
- if generated == "":
110
- generated = "is displaying neutral behavior."
111
-
112
- # Force le début si besoin
113
- if not generated.lower().startswith("the cat"):
114
- generated = "The cat " + generated.lower()
115
 
116
- # Garde seulement la première phrase
117
- generated = generated.split('\n')[0].split('.')[0].strip()
118
- if not generated.endswith('.'):
119
- generated += "."
120
 
121
- # Capitalise correctement
122
- if generated.startswith("The cat"):
123
- generated = "The cat" + generated[7:].capitalize()
124
-
125
- return generated
126
  # ==========================================
127
- # 3. PIPELINE ANALYSE
128
  # ==========================================
129
  @spaces.GPU(duration=120)
130
  def analyze_cat_v12_final(video_path):
@@ -167,7 +147,7 @@ def analyze_cat_v12_final(video_path):
167
  clip.close()
168
  t_audio = time.time() - t_0
169
 
170
- # --- B. VISION (Processor chargé à chaque appel) ---
171
  t_1 = time.time()
172
  vlm_proc = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
173
 
@@ -201,9 +181,10 @@ def analyze_cat_v12_final(video_path):
201
  vlm_out = vlm_model.generate(
202
  **vlm_inputs,
203
  max_new_tokens=80,
204
- do_sample=True,
205
- temperature=0.7,
206
- top_p=0.9
 
207
  )
208
 
209
  gen_tokens = vlm_out[0][input_length:]
@@ -217,7 +198,7 @@ def analyze_cat_v12_final(video_path):
217
 
218
  t_vlm = time.time() - t_1
219
 
220
- # --- C. JUGE ---
221
  t_2 = time.time()
222
  top_idx = np.argmax(audio_probs)
223
  audio_ctx = f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
@@ -237,7 +218,7 @@ def analyze_cat_v12_final(video_path):
237
  )
238
  fig.update_layout(height=400, showlegend=False)
239
 
240
- # --- E. RAPPORT ---
241
  t_total = time.time() - start_total
242
  report = f"""⚖️ VERDICT JUGE : {judge_decision}
243
 
@@ -261,6 +242,8 @@ def analyze_cat_v12_final(video_path):
261
  # --- Interface Gradio ---
262
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
263
  gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
 
 
264
  with gr.Row():
265
  with gr.Column():
266
  video_input = gr.Video(label="Vidéo du chat")
@@ -271,4 +254,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
271
 
272
  btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])
273
 
274
- demo.launch()
 
18
  AutoModelForCausalLM,
19
  AutoTokenizer
20
  )
21
+ from moviepy import VideoFileClip
22
 
23
  # --- Configuration ---
24
  CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy', 'hunt', 'in_heat', 'mother_call', 'pain', 'wants_attention']
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ==========================================
28
+ # 1. CHARGEMENT DES MODÈLES
29
  # ==========================================
30
  def load_models():
31
  print("📥 Initialisation CatSense v12.13 (Vision Pure Mode)...")
32
 
33
+ # Modèle VLM
34
  vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
38
 
39
+ # LLM Juge
40
  llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
41
  llm_tok = AutoTokenizer.from_pretrained(llm_id)
42
  llm_model = AutoModelForCausalLM.from_pretrained(
 
61
 
62
  return vlm_model, llm_tok, llm_model, audio_models
63
 
64
+ # Chargement global
65
  vlm_model, llm_tok, llm_model, audio_models = load_models()
66
 
67
  # ==========================================
68
+ # 2. JUGE OPTIMISÉ (nouveau)
69
  # ==========================================
70
  def call_peace_judge(audio_top, vlm_desc):
 
71
  messages = [
72
  {
73
  "role": "system",
74
+ "content": """You are a cat behavior expert. Match audio prediction with visual description.
75
+ Answer ONLY: "The cat is [ONE WORD]: affectionate/angry/backing_off/defensive/hungry/happy/hunting/in_heat/calling_kittens/in_pain/wanting_attention/calm"
76
+ No explanation. No extra text. Match exactly."""
77
  },
78
  {
79
  "role": "user",
80
+ "content": f"AUDIO: {audio_top}\nVISION: {vlm_desc}\n\nFINAL JUDGEMENT:"
81
  }
82
  ]
83
 
84
+ input_text = llm_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
85
  inputs = llm_tok(input_text, return_tensors="pt").to(DEVICE)
86
 
87
  with torch.no_grad():
88
  outputs = llm_model.generate(
89
+ **inputs,
90
+ max_new_tokens=20,
91
+ do_sample=False,
92
+ temperature=0.0,
 
93
  pad_token_id=llm_tok.eos_token_id,
94
  eos_token_id=llm_tok.eos_token_id
95
  )
96
 
 
97
  generated = llm_tok.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
98
 
99
+ # Extraction stricte du mot-clé
100
+ for cat in CATEGORIES + ['calm']:
101
+ if cat.replace('_', ' ') in generated.lower():
102
+ return f"The cat is {cat.replace('_', ' ')}."
103
 
104
+ return "The cat is calm."
105
+
 
 
 
106
  # ==========================================
107
+ # 3. PIPELINE ANALYSE COMPLETE
108
  # ==========================================
109
  @spaces.GPU(duration=120)
110
  def analyze_cat_v12_final(video_path):
 
147
  clip.close()
148
  t_audio = time.time() - t_0
149
 
150
+ # --- B. VISION (Ton prompt parfait + params optimisés) ---
151
  t_1 = time.time()
152
  vlm_proc = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
153
 
 
181
  vlm_out = vlm_model.generate(
182
  **vlm_inputs,
183
  max_new_tokens=80,
184
+ do_sample=False, # ✅ Greedy decoding
185
+ temperature=0.0, # ✅ Zéro créativité
186
+ top_p=0.9,
187
+ pad_token_id=vlm_proc.tokenizer.eos_token_id
188
  )
189
 
190
  gen_tokens = vlm_out[0][input_length:]
 
198
 
199
  t_vlm = time.time() - t_1
200
 
201
+ # --- C. JUGE OPTIMISÉ ---
202
  t_2 = time.time()
203
  top_idx = np.argmax(audio_probs)
204
  audio_ctx = f"{CATEGORIES[top_idx].upper()} ({audio_probs[top_idx]*100:.1f}%)"
 
218
  )
219
  fig.update_layout(height=400, showlegend=False)
220
 
221
+ # --- E. RAPPORT FINAL ---
222
  t_total = time.time() - start_total
223
  report = f"""⚖️ VERDICT JUGE : {judge_decision}
224
 
 
242
  # --- Interface Gradio ---
243
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
244
  gr.Markdown("# 🐱 CatSense v12.13 - Vision Pure Mode")
245
+ gr.Markdown("✅ **SmolVLM2-256M** + **SmolLM2-135M Juge** + Audio Ensemble")
246
+
247
  with gr.Row():
248
  with gr.Column():
249
  video_input = gr.Video(label="Vidéo du chat")
 
254
 
255
  btn.click(analyze_cat_v12_final, inputs=video_input, outputs=[report_out, chart_out])
256
 
257
+ demo.launch()