ericjedha commited on
Commit
4639733
·
verified ·
1 Parent(s): adaf41a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -13
app.py CHANGED
@@ -25,23 +25,25 @@ CATEGORIES = ['affection', 'angry', 'back_off', 'defensive', 'feed_me', 'happy',
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ==========================================
28
- # 1. CHARGEMENT DE LA TRINITÉ
29
  # ==========================================
30
  def load_models():
31
- print("📥 Initialisation CatSense v12.12 (No-Assistant Fix)...")
32
 
 
33
  vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
34
- vlm_proc = AutoProcessor.from_pretrained(vlm_id)
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
38
 
 
39
  llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
40
  llm_tok = AutoTokenizer.from_pretrained(llm_id)
41
  llm_model = AutoModelForCausalLM.from_pretrained(
42
  llm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
43
  ).to(DEVICE).eval()
44
 
 
45
  audio_models = {}
46
  for p, repo, f in [('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
47
  ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth')]:
@@ -50,30 +52,38 @@ def load_models():
50
  m.load_state_dict(torch.load(path, map_location=DEVICE)['model_state_dict'])
51
  audio_models[p] = m.to(DEVICE).eval()
52
 
53
- path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
54
  model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
55
  sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
56
  model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
57
  audio_models['C'] = model_c.to(DEVICE).eval()
58
  audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
59
 
60
- return vlm_proc, vlm_model, llm_tok, llm_model, audio_models
61
 
62
- vlm_proc, vlm_model, llm_tok, llm_model, audio_models = load_models()
 
63
 
64
  # ==========================================
65
- # 2. LOGIQUE DU JUGE
66
  # ==========================================
67
  def call_peace_judge(audio_top, vlm_desc):
68
  prompt_text = f"Audio Score: {audio_top}\nVisual Analysis: {vlm_desc}\nVerdict:"
69
  inputs = llm_tok(prompt_text, return_tensors="pt").to(DEVICE)
70
  with torch.no_grad():
71
- outputs = llm_model.generate(**inputs, max_new_tokens=20, temperature=0.01, do_sample=False)
 
 
 
 
 
 
 
72
  res = llm_tok.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
73
  return res.strip().split('\n')[0]
74
 
75
  # ==========================================
76
- # 3. PIPELINE ANALYSE
77
  # ==========================================
78
  @spaces.GPU(duration=120)
79
  def analyze_cat_v12_final(video_path):
@@ -108,9 +118,12 @@ def analyze_cat_v12_final(video_path):
108
  clip.close()
109
  t_audio = time.time() - t_0
110
 
111
- # --- B. VISION (CORRIGÉ : une seule fois, via apply_chat_template) ---
112
  t_1 = time.time()
113
 
 
 
 
114
  vlm_prompt = (
115
  "Describe the cat in the video\n"
116
  "count ears, mouth, tail and body posture.\n"
@@ -128,11 +141,17 @@ def analyze_cat_v12_final(video_path):
128
  ).to(DEVICE)
129
 
130
  with torch.no_grad():
131
- vlm_out = vlm_model.generate(**vlm_inputs, max_new_tokens=100, do_sample=False)
 
 
 
 
 
 
132
 
133
  vlm_res = vlm_proc.batch_decode(vlm_out, skip_special_tokens=True)[0]
134
 
135
- # Nettoyage robuste de la réponse
136
  if "assistant" in vlm_res.lower():
137
  vlm_clean = vlm_res.split("assistant")[-1].strip()
138
  else:
@@ -175,7 +194,7 @@ def analyze_cat_v12_final(video_path):
175
 
176
  # --- Interface Gradio ---
177
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
178
- gr.Markdown("# 🐱 CatSense v12.12 - Raw Mode")
179
  with gr.Row():
180
  with gr.Column():
181
  video_input = gr.Video()
 
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
  # ==========================================
28
+ # 1. CHARGEMENT DES MODÈLES (sans le VLM processor)
29
  # ==========================================
30
  def load_models():
31
+ print("📥 Initialisation CatSense v12.12 (Fresh Processor Fix)...")
32
 
33
+ # On charge SEULEMENT le modèle VLM (lourd), pas le processor
34
  vlm_id = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
 
35
  vlm_model = AutoModelForImageTextToText.from_pretrained(
36
  vlm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
37
  ).to(DEVICE).eval()
38
 
39
+ # LLM
40
  llm_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
41
  llm_tok = AutoTokenizer.from_pretrained(llm_id)
42
  llm_model = AutoModelForCausalLM.from_pretrained(
43
  llm_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
44
  ).to(DEVICE).eval()
45
 
46
+ # Audio models
47
  audio_models = {}
48
  for p, repo, f in [('A', 'ericjedha/pilier_a', 'best_pillar_a_e29_f1_0_9005.pth'),
49
  ('B', 'ericjedha/pilier_b', 'best_pillar_b_f1_09103.pth')]:
 
52
  m.load_state_dict(torch.load(path, map_location=DEVICE)['model_state_dict'])
53
  audio_models[p] = m.to(DEVICE).eval()
54
 
55
+ path_c = hf_hub_download(repo_id="ericjedha/pilier_c", filename="best_pillar_c_ast_v95_2_f1_0_9109.pth")
56
  model_c = ASTForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=len(CATEGORIES), ignore_mismatched_sizes=True)
57
  sd = torch.load(path_c, map_location=DEVICE)['model_state_dict']
58
  model_c.load_state_dict({k.replace('ast.', ''): v for k, v in sd.items()}, strict=False)
59
  audio_models['C'] = model_c.to(DEVICE).eval()
60
  audio_models['ast_ext'] = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
61
 
62
+ return vlm_model, llm_tok, llm_model, audio_models
63
 
64
+ # Chargement global des modèles lourds (pas du processor VLM)
65
+ vlm_model, llm_tok, llm_model, audio_models = load_models()
66
 
67
  # ==========================================
68
+ # 2. LOGIQUE DU JUGE (avec stochasticité)
69
  # ==========================================
70
  def call_peace_judge(audio_top, vlm_desc):
71
  prompt_text = f"Audio Score: {audio_top}\nVisual Analysis: {vlm_desc}\nVerdict:"
72
  inputs = llm_tok(prompt_text, return_tensors="pt").to(DEVICE)
73
  with torch.no_grad():
74
+ outputs = llm_model.generate(
75
+ **inputs,
76
+ max_new_tokens=20,
77
+ do_sample=True,
78
+ temperature=0.4,
79
+ top_p=0.9,
80
+ pad_token_id=llm_tok.eos_token_id
81
+ )
82
  res = llm_tok.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
83
  return res.strip().split('\n')[0]
84
 
85
  # ==========================================
86
+ # 3. PIPELINE ANALYSE (Processor VLM FRESH à chaque appel)
87
  # ==========================================
88
  @spaces.GPU(duration=120)
89
  def analyze_cat_v12_final(video_path):
 
118
  clip.close()
119
  t_audio = time.time() - t_0
120
 
121
+ # --- B. VISION (Processor FRESH à chaque appel) ---
122
  t_1 = time.time()
123
 
124
+ # 🔑 CORRECTION MAJEURE : on charge le processor ICI
125
+ vlm_proc = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
126
+
127
  vlm_prompt = (
128
  "Describe the cat in the video\n"
129
  "count ears, mouth, tail and body posture.\n"
 
141
  ).to(DEVICE)
142
 
143
  with torch.no_grad():
144
+ vlm_out = vlm_model.generate(
145
+ **vlm_inputs,
146
+ max_new_tokens=100,
147
+ do_sample=True, # ✅ Stochastic
148
+ temperature=0.7, # ✅ Variabilité
149
+ top_p=0.9
150
+ )
151
 
152
  vlm_res = vlm_proc.batch_decode(vlm_out, skip_special_tokens=True)[0]
153
 
154
+ # Nettoyage robuste
155
  if "assistant" in vlm_res.lower():
156
  vlm_clean = vlm_res.split("assistant")[-1].strip()
157
  else:
 
194
 
195
  # --- Interface Gradio ---
196
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
197
+ gr.Markdown("# 🐱 CatSense v12.12 - Fresh Processor Mode")
198
  with gr.Row():
199
  with gr.Column():
200
  video_input = gr.Video()