Bliss-Ruth committed
Commit 66dc1a0 · verified · 1 Parent(s): 24093cb

updated app.py

Files changed (1)
  1. app.py +191 -31
app.py CHANGED
@@ -8,6 +8,9 @@ import numpy as np
 from PIL import Image
 import tempfile
 import os
+import json
+from datetime import datetime
+import pandas as pd
 
 # Your exact model class
 class XCLIPSignLanguageClassifier(nn.Module):
@@ -38,15 +41,21 @@ processor = XCLIPProcessor.from_pretrained("microsoft/xclip-base-patch32")
 # Load your trained model
 try:
     checkpoint = torch.load("best_xclip_model.pth", map_location=device, weights_only=False)
-    model = XCLIPSignLanguageClassifier(num_classes=len(checkpoint["label_to_id"])).to(device)
+    model = XCLIPSignLanguageClassifier(num_classes=len(checkpoint["id_to_label"])).to(device)
     model.load_state_dict(checkpoint["model_state_dict"])
     model.eval()
     id_to_label = checkpoint["id_to_label"]
+    label_to_id = {v: k for k, v in id_to_label.items()}
     print(f"✅ Model loaded! Can recognize {len(id_to_label)} signs: {list(id_to_label.values())}")
 except Exception as e:
     print(f"❌ Error loading model: {e}")
     exit(1)
 
+# Continuous learning storage
+FEEDBACK_FILE = "user_feedback.csv"
+if not os.path.exists(FEEDBACK_FILE):
+    pd.DataFrame(columns=['timestamp', 'video_path', 'predicted_label', 'correct_label', 'confidence']).to_csv(FEEDBACK_FILE, index=False)
+
 def extract_frames(video_path, num_frames=8):
     """Extract frames from video file"""
     try:
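Note on the hunk above: the checkpoint is a plain pickled dict that carries label maps alongside the weights (`weights_only=False` keeps `torch.load` permissive for whatever Python objects the training run pickled in), and the classifier head is now sized from `id_to_label`, with `label_to_id` rebuilt by inversion rather than read from the file. A minimal sketch of the checkpoint layout this load path assumes; the training script is not part of this commit, so the save call and example labels are illustrative, and `trained_model` stands in for the fine-tuned `XCLIPSignLanguageClassifier` instance:

```python
import torch

# Illustrative save side (not in this repo). The load code above only
# requires these two keys to be present in best_xclip_model.pth:
checkpoint = {
    "model_state_dict": trained_model.state_dict(),  # fine-tuned classifier weights
    "id_to_label": {                                 # class index -> sign name
        0: "hello", 1: "how", 2: "good", 3: "please", 4: "sign language",
    },
}
torch.save(checkpoint, "best_xclip_model.pth")
```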
@@ -76,24 +85,8 @@ def extract_frames(video_path, num_frames=8):
         print(f"Frame extraction error: {e}")
         return [Image.new("RGB", (224, 224), (128, 128, 128)) for _ in range(num_frames)]
 
-def predict_video(video_file, user_correction=None):
-    """Predict sign language from uploaded video"""
-    try:
-        # Get prediction
-        predicted_label, confidence = predict_sign(video_file, model, processor, id_to_label, device)
-
-        # Format results - EXACT SAME as our Colab interface
-        result = f"🎯 **Prediction**: {predicted_label}\n"
-        result += f"📊 **Confidence**: {confidence*100:.1f}%\n"
-        result += f"🔍 **Model**: X-CLIP Fine-tuned"
-
-        return result
-
-    except Exception as e:
-        return f"❌ Error processing video: {str(e)}"
-
 def predict_sign(video_path, model, processor, id_to_label, device):
-    """Core prediction function"""
+    """Core prediction function with detailed outputs"""
     try:
         # Sample frames
         frames = extract_frames(video_path)
@@ -110,23 +103,190 @@ def predict_sign(video_path, model, processor, id_to_label, device):
         logits = model(input_ids, attention_mask, pixel_values)
         probs = torch.softmax(logits, dim=1)
         confidence, pred_class = torch.max(probs, 1)
+
+        # Get all probabilities for detailed analysis
+        all_probs = probs.cpu().numpy()[0]
 
-        return id_to_label[pred_class.item()], confidence.item()
+        predicted_label = id_to_label[pred_class.item()]
+        confidence_value = confidence.item()
+
+        # Create confidence breakdown
+        confidence_details = []
+        for i, prob in enumerate(all_probs):
+            confidence_details.append(f"{id_to_label[i]}: {prob*100:.1f}%")
+
+        return predicted_label, confidence_value, confidence_details, all_probs
 
     except Exception as e:
         print(f"❌ Prediction error: {e}")
-        return "Unknown", 0.0
-
-# Create the interface - EXACT SAME as our Colab version
-demo = gr.Interface(
-    fn=predict_video,
-    inputs=gr.Video(label="📹 Upload Sign Language Video"),
-    outputs=gr.Markdown(label=" Prediction Results"),
-    title="🤟 Ugandan Sign Language Recognition",
-    description="Upload a video of sign language and the AI will predict which sign it is!",
-    examples=[]  # You can add example videos later
-)
+        return "Unknown", 0.0, [], []
+
+def save_feedback(video_path, predicted_label, correct_label, confidence):
+    """Save user feedback for continuous learning"""
+    try:
+        feedback_data = {
+            'timestamp': datetime.now().isoformat(),
+            'video_path': video_path,
+            'predicted_label': predicted_label,
+            'correct_label': correct_label,
+            'confidence': confidence
+        }
+
+        # Append to CSV
+        df = pd.read_csv(FEEDBACK_FILE)
+        df = pd.concat([df, pd.DataFrame([feedback_data])], ignore_index=True)
+        df.to_csv(FEEDBACK_FILE, index=False)
+
+        return "✅ Feedback saved! We'll use this to improve the model."
+    except Exception as e:
+        return f"❌ Error saving feedback: {str(e)}"
+
+def predict_video(video_file):
+    """Predict sign language from uploaded video with detailed results"""
+    try:
+        if video_file is None:
+            return "## 📹 Please upload a video file", "", gr.update(visible=False)
+
+        # Get detailed prediction
+        predicted_label, confidence, confidence_details, all_probs = predict_sign(
+            video_file, model, processor, id_to_label, device
+        )
+
+        # Create detailed results
+        result = f"""
+## **Sign Language Translation Result**:
+
+### **Detected Sign:** {predicted_label}
+
+### **Confidence Level:** {confidence*100:.1f}%
+
+### **Translation:** This sign means "{predicted_label}" in Ugandan Sign Language
+
+---
+
+## Detailed Analysis:
+
+**Confidence Breakdown:**
+"""
+
+        # Add confidence bars for each class
+        for i, (label, prob) in enumerate(zip(id_to_label.values(), all_probs)):
+            bar_length = int(prob * 20)  # Scale to 20 characters
+            bar = "█" * bar_length + "░" * (20 - bar_length)
+            result += f"\n**{label}:** {bar} {prob*100:.1f}%"
+
+        result += f"""
+
+---
+
+### 🔧 Model Information:
+- **Model:** X-CLIP Fine-tuned on Ugandan Sign Language
+- **Supported Signs:** {len(id_to_label)} classes
+- **Top Confidence:** {confidence*100:.1f}%
+- **All Classes:** {', '.join(id_to_label.values())}
+
+---
+
+**🤔 Was this prediction correct?** Use the feedback section below to help improve the model!
+"""
+
+        # Show feedback section
+        feedback_section = gr.update(visible=True)
+
+        return result, predicted_label, feedback_section
+
+    except Exception as e:
+        return f"## ❌ Error Processing Video\n\n**Error:** {str(e)}\n\nPlease try another video file.", "", gr.update(visible=False)
+
+def submit_feedback(predicted_label, user_correction, video_path):
+    """Handle user feedback for continuous learning"""
+    if user_correction == "" or user_correction is None:
+        return "⚠️ Please select the correct sign label"
+
+    if user_correction == predicted_label:
+        return "✅ Thank you for confirming the prediction was correct!"
+
+    # Save correction feedback
+    result = save_feedback(video_path, predicted_label, user_correction, 0.0)
+
+    # Additional improvement message
+    result += "\n\n📈 **Model Improvement:** The model will learn from this correction!"
+    result += f"\n**Wrong:** {predicted_label} → **Correct:** {user_correction}"
+    result += "\n\n💡 This feedback will be used to retrain and improve the model accuracy."
+
+    return result
+
+# Create the enhanced interface
+with gr.Blocks(theme=gr.themes.Soft(), title="Ugandan Sign Language Translator") as demo:
+    gr.Markdown("""
+    # 🤟 Ugandan Sign Language Translation Tool
+
+    **Upload a video of Ugandan Sign Language and get instant translation with detailed analysis!**
+
+    *Supported signs: hello, how, good, please, sign language*
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            video_input = gr.Video(
+                label="📹 Upload Sign Language Video",
+                sources=["upload"]  # gr.Video hands the function a filepath by default; it has no `type` kwarg
+            )
+            predict_btn = gr.Button("🚀 Analyze Sign Language", variant="primary")
+
+        with gr.Column():
+            results_output = gr.Markdown(
+                label="🎯 Translation Results",
+                value="## 📤 Upload a video to get started..."
+            )
+
+    # Hidden state for current prediction
+    current_prediction = gr.State()
+    current_video_path = gr.State()
+
+    # Feedback section (initially hidden)
+    with gr.Row(visible=False) as feedback_row:
+        with gr.Column():
+            gr.Markdown("## 💡 Help Improve The Model")
+            correction_dropdown = gr.Dropdown(
+                choices=list(id_to_label.values()),
+                label="What was the correct sign?",
+                info="Select the actual sign in the video"
+            )
+            feedback_btn = gr.Button("📈 Submit Correction", variant="secondary")
+            feedback_output = gr.Markdown()
+
+    # Prediction logic
+    predict_btn.click(
+        fn=predict_video,
+        inputs=[video_input],
+        outputs=[results_output, current_prediction, feedback_row]
+    ).then(
+        lambda video: video,
+        inputs=[video_input],
+        outputs=[current_video_path]
+    )
+
+    # Feedback logic
+    feedback_btn.click(
+        fn=submit_feedback,
+        inputs=[current_prediction, correction_dropdown, current_video_path],
+        outputs=[feedback_output]
+    )
+
+    # Examples
+    gr.Markdown("### 📚 How to use:")
+    gr.Markdown("""
+    1. **Upload** a video of someone performing sign language
+    2. **Click Analyze** to get the translation
+    3. **Review** the detailed confidence analysis
+    4. **Provide feedback** if the prediction was wrong (this helps improve the model!)
+    """)
 
 # For Hugging Face Spaces deployment
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch(
+        share=True,
+        show_error=True
+    )
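For reference, the refactor widens `predict_sign` from the old `(label, confidence)` pair to a 4-tuple, with `("Unknown", 0.0, [], [])` as the failure fallback. An illustrative call against the new signature (the file name is made up):

```python
# Illustrative only: unpacking the new 4-tuple return of predict_sign.
label, confidence, details, probs = predict_sign(
    "greeting.mp4", model, processor, id_to_label, device
)
print(f"Top prediction: {label} ({confidence*100:.1f}%)")
for line in details:  # per-class strings such as "hello: 82.3%"
    print(line)
```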
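One caveat worth noting: this commit only logs corrections to `user_feedback.csv`; nothing retrains in-process, and on Spaces the stored `video_path` points at a temporary upload that will not survive a restart, so the raw clips would need to be copied somewhere persistent before any retraining run. A sketch of how the log could later be turned into a correction set (column names come from the code above; everything else is assumed):

```python
import pandas as pd

# Read the feedback log written by save_feedback().
df = pd.read_csv("user_feedback.csv")

# Keep only rows where the model was actually wrong.
corrections = df[df["predicted_label"] != df["correct_label"]]

# Hypothetical fine-tuning inputs: (video path, human-verified label).
retrain_pairs = list(zip(corrections["video_path"], corrections["correct_label"]))
print(f"{len(retrain_pairs)} corrected clips available for a future fine-tuning pass")
```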