Bliss-Ruth committed
Commit 908b58d · verified · 1 Parent(s): a2c84bd

Update app.py

Files changed (1):
  1. app.py +50 -27
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py - CLEAN MINIMAL INTERFACE (No Confidence Bars/Tabs)
+# app.py - CORRECTED VERSION (Uses MinimalClassifier from your training)
 import torch
 import torch.nn as nn
 from transformers import XCLIPProcessor, XCLIPModel
@@ -13,10 +13,11 @@ import os
 print("🚀 Loading Ugandan Sign Language Model...")
 
 # ============================================================================
-# MODEL SETUP - FIXED VERSION
+# MODEL SETUP - MINIMALCLASSIFIER (Matches Your Training)
 # ============================================================================
 
 class MinimalClassifier(nn.Module):
+    """SIMPLE classifier - matches your training notebook exactly"""
     def __init__(self, input_dim=512, num_classes=85, dropout=0.5):
         super().__init__()
         self.classifier = nn.Sequential(
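Note: the hunk cuts off inside `nn.Sequential(`, so the head's layers are not visible here. A minimal sketch of a classifier consistent with this signature, assuming a plain dropout-plus-linear stack (the actual body may differ):

```python
import torch.nn as nn

class MinimalClassifier(nn.Module):
    """Hypothetical completion of the truncated class body."""
    def __init__(self, input_dim=512, num_classes=85, dropout=0.5):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),                # assumption: one dropout layer
            nn.Linear(input_dim, num_classes),  # 512-d X-CLIP embed -> logits
        )

    def forward(self, x):
        # x: (batch, input_dim) pooled video embeddings
        return self.classifier(x)
```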
@@ -32,36 +33,24 @@ processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")
 xclip_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(device)
 xclip_model.eval()
 
-# Load your trained model - WITH ERROR HANDLING
+# Load your trained model - WITH MINIMALCLASSIFIER
 try:
-    checkpoint = torch.load("best_xclip_model.pth", map_location=device, weights_only=False)
+    checkpoint = torch.load("finetuned_xclip_model.pth", map_location=device, weights_only=False)
 
     # DEBUG: Check what's in the checkpoint
     print(f"🔍 Checkpoint keys: {list(checkpoint.keys())}")
 
-    # FIX: Handle missing 'num_classes' key
+    # Get num_classes
     if 'num_classes' in checkpoint:
         num_classes = checkpoint['num_classes']
+    elif 'id_to_label' in checkpoint:
+        num_classes = len(checkpoint['id_to_label'])
     else:
-        # Try to infer number of classes
-        if 'id_to_label' in checkpoint:
-            num_classes = len(checkpoint['id_to_label'])
-        elif 'label_to_id' in checkpoint:
-            num_classes = len(checkpoint['label_to_id'])
-        else:
-            # Count from model weights
-            for key in checkpoint.keys():
-                if 'model_state_dict' in checkpoint:
-                    weight_key = [k for k in checkpoint['model_state_dict'].keys() if 'classifier' in k and 'weight' in k]
-                    if weight_key:
-                        num_classes = checkpoint['model_state_dict'][weight_key[0]].shape[0]
-                        break
-            else:
-                num_classes = 85  # Default fallback
+        num_classes = 85  # Default
 
     print(f"✅ Using num_classes: {num_classes}")
 
-    # Initialize model
+    # Initialize with MINIMALCLASSIFIER (your actual architecture)
     model = MinimalClassifier(
         input_dim=512,
         num_classes=num_classes,
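Note: the commit drops the weight-shape fallback the old code used when neither `num_classes` nor `id_to_label` was stored. If that safety net is still wanted, a compact version of the same idea (a sketch; `infer_num_classes` is a hypothetical helper, not in the app):

```python
def infer_num_classes(checkpoint, default=85):
    """Hypothetical helper reproducing the dropped fallback logic."""
    if 'num_classes' in checkpoint:
        return checkpoint['num_classes']
    for key in ('id_to_label', 'label_to_id'):
        if key in checkpoint:
            return len(checkpoint[key])
    # Fall back to the classifier head's output dimension:
    # nn.Linear stores its weight as (out_features, in_features).
    state = checkpoint.get('model_state_dict', checkpoint)
    weight_keys = [k for k in state if 'classifier' in k and k.endswith('weight')]
    if weight_keys:
        return state[weight_keys[-1]].shape[0]  # last key = final layer
    return default
```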
@@ -72,14 +61,12 @@ try:
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
     else:
-        # If checkpoint IS the state dict
         model.load_state_dict(checkpoint)
 
     # Load label mappings
     if 'id_to_label' in checkpoint:
         id_to_label = checkpoint['id_to_label']
     else:
-        # Create default mapping
         id_to_label = {i: f"class_{i}" for i in range(num_classes)}
         print("⚠️ Created default label mapping")
 
@@ -91,7 +78,7 @@ try:
 
 except Exception as e:
     print(f"❌ Error loading model: {e}")
-    print("💡 TIP: Make sure your model file has 'num_classes' or 'id_to_label' key")
+    print("💡 Make sure your model file uses MinimalClassifier architecture")
     exit(1)
 
 # ============================================================================
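Note: most failure modes this `except` block catches come down to what the training side saved. For the loader above to take its happy path, the checkpoint would be written roughly like this (an assumption about the training notebook, which is not part of this commit):

```python
# Hypothetical save step on the training side: wraps the weights
# with the metadata keys the loader above looks for.
torch.save({
    'model_state_dict': model.state_dict(),
    'num_classes': num_classes,
    'id_to_label': id_to_label,   # {class_id: sign gloss}
}, 'finetuned_xclip_model.pth')
```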
@@ -133,7 +120,7 @@ def extract_frames(video_path, num_frames=8):
         return [Image.new('RGB', (224, 224), (0, 0, 0)) for _ in range(num_frames)]
 
 def predict_sign(video_path):
-    """Predict sign from video"""
+    """Predict sign from video """
     try:
         frames = extract_frames(video_path)
 
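Note: `extract_frames` itself is unchanged and not shown, but the fallback line above implies its contract: always return `num_frames` RGB PIL images. A sketch of an implementation matching that contract, assuming OpenCV-based uniform sampling (the app's actual version may differ):

```python
import cv2
from PIL import Image

def extract_frames(video_path, num_frames=8):
    """Sketch: uniformly sample num_frames RGB frames, black frames on failure."""
    try:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = []
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i * total / num_frames))
            ok, frame = cap.read()
            if not ok:
                raise ValueError("unreadable frame")
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        cap.release()
        return frames
    except Exception:
        # Contract implied by the diff: always return num_frames PIL images
        return [Image.new('RGB', (224, 224), (0, 0, 0)) for _ in range(num_frames)]
```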
@@ -145,6 +132,7 @@ def predict_sign(video_path):
         attention_mask = text_inputs['attention_mask'].to(device)
 
         with torch.no_grad():
+            # Extract features using X-CLIP
             outputs = xclip_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -153,6 +141,7 @@
             )
             video_embeds = outputs.video_embeds
 
+            # Classify with MinimalClassifier (takes features as input)
             logits = model(video_embeds)
             probs = torch.softmax(logits, dim=1)
             confidence, pred_class = torch.max(probs, 1)
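Note: the two comments added in these hunks describe a two-stage pipeline: X-CLIP produces a pooled 512-d video embedding, and `MinimalClassifier` maps it to class logits. End to end it looks roughly like the sketch below, using the Hugging Face `XCLIPProcessor`/`XCLIPModel` API; the exact processor call is not visible in the diff, and `extract_frames`, `id_to_label`, and `model` are the app's globals, with "example.mp4" a hypothetical input:

```python
import torch
from transformers import XCLIPProcessor, XCLIPModel

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")
xclip_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(device).eval()

frames = extract_frames("example.mp4")   # 8 PIL images
texts = list(id_to_label.values())       # one text prompt per class
inputs = processor(text=texts, videos=[frames], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = xclip_model(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        pixel_values=inputs["pixel_values"].to(device),
    )
    video_embeds = outputs.video_embeds   # (1, 512) pooled video feature
    logits = model(video_embeds)          # stage 2: MinimalClassifier head
    probs = torch.softmax(logits, dim=1)
    confidence, pred_class = torch.max(probs, 1)

print(id_to_label[pred_class.item()], round(confidence.item(), 3))
```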
@@ -265,6 +254,16 @@ h1 {
     border-left: 4px solid #ff6b35 !important;
     margin-top: 20px !important;
 }
+
+/* Add to your custom_css */
+#video-upload {
+    border: 2px dashed #ff6b35 !important;
+}
+
+#video-upload:hover {
+    border-color: #e55a2b !important;
+    background: #3d3d3d !important;
+}
 """
 
 def predict_video_clean(video_file):
@@ -331,9 +330,12 @@ with gr.Blocks(css=custom_css, title="Ugandan Sign Language Translator") as demo
         with gr.Column(scale=1):
             gr.Markdown("### 📤 Upload Video")
             video_input = gr.Video(
-                label="",
-                sources=["upload"]
+                label="📱 Upload or Record Video",
+                sources=["upload", "webcam"],
+                elem_id="video-upload"
             )
+
+
 
             # Action buttons
             with gr.Row():
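Note: `elem_id="video-upload"` is what connects this component to the `#video-upload` CSS rules added earlier (the original hunk was missing the comma after `sources`, fixed above). A stripped-down sketch of the wiring; the real app passes the full `custom_css` string:

```python
import gradio as gr

custom_css = """
#video-upload { border: 2px dashed #ff6b35 !important; }
#video-upload:hover { border-color: #e55a2b !important; background: #3d3d3d !important; }
"""

with gr.Blocks(css=custom_css) as demo:
    video_input = gr.Video(
        label="📱 Upload or Record Video",
        sources=["upload", "webcam"],
        elem_id="video-upload",  # matches the #video-upload CSS selector
    )

demo.launch()
```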
@@ -359,6 +361,27 @@ with gr.Blocks(css=custom_css, title="Ugandan Sign Language Translator") as demo
                 )
             feedback_btn = gr.Button("📝 Submit Correction", variant="secondary")
             feedback_output = gr.Markdown()
+
+    gr.Markdown("---")
+    gr.Markdown("### 📚 Example Videos")
+
+    # Create examples from your dataset (same as your testing UI)
+    example_videos = []
+    for i in range(min(3, len(full_df))):
+        if os.path.exists(full_df.iloc[i]['video_path']):
+            example_videos.append([full_df.iloc[i]['video_path']])
+
+    if example_videos:
+        gr.Examples(
+            examples=example_videos,
+            inputs=[video_input],
+            label="Try these example videos:",
+            # Optional: You can also add outputs if you want auto-prediction
+            # outputs=[results_output],
+            # fn=predict_video_clean,
+        )
+    else:
+        gr.Markdown("*No example videos available*")
 
     # Hidden states
     current_prediction = gr.State()
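Note: the commented-out `outputs=`/`fn=` arguments show how to make the examples auto-predict on click. A sketch of the enabled form, with `cache_examples=False` so the model does not run at startup; `results_output` is an assumed name for the app's results component:

```python
if example_videos:
    gr.Examples(
        examples=example_videos,
        inputs=[video_input],
        outputs=[results_output],   # assumed name of the results component
        fn=predict_video_clean,
        cache_examples=False,       # don't run the model at app startup
        label="Try these example videos:",
    )
```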
 