vedchamp07 committed on
Commit
3a3f6c6
·
1 Parent(s): fd44722

Add CAPTCHA breaker app

Browse files
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app for testing CAPTCHA model.
3
+ Allows uploading CAPTCHA images and getting predictions with preprocessing.
4
+ """
5
+ import gradio as gr
6
+ import torch
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ import string
10
+ from pathlib import Path
11
+ import numpy as np
12
+ import cv2
13
+
14
+ from src.model import CTCCaptchaModel
15
+
16
+
17
+ # Setup
18
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ CHARACTERS = string.digits + string.ascii_lowercase + string.ascii_uppercase
20
+ MODEL_PATH = Path("models/captcha_model_v3.pth")
21
+
22
+ # Load model
23
+ model = CTCCaptchaModel(num_classes=len(CHARACTERS), use_attention=True)
24
+
25
+ # Load checkpoint
26
+ checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
27
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
28
+ model.load_state_dict(checkpoint['model_state_dict'])
29
+ else:
30
+ model.load_state_dict(checkpoint)
31
+
32
+ model.to(DEVICE)
33
+ model.eval()
34
+
35
+ # Image preprocessing transforms
36
+ transform = transforms.Compose([
37
+ transforms.Resize((60, 160)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize(mean=[0.5], std=[0.5])
40
+ ])
41
+
42
+
43
def preprocess_image(image):
    """
    Clean up a CAPTCHA image before it is fed to the model.

    The image is converted to grayscale, binarized with Otsu's
    threshold, and denoised with a morphological closing.

    Args:
        image: PIL Image

    Returns:
        Preprocessed PIL Image (single-channel, binary)
    """
    # Grayscale numpy view of the input
    gray = np.array(image.convert('L'))

    # Otsu's method picks the binarization threshold automatically
    _, thresholded = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Close small holes and specks with a 3x3 elliptical kernel
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    cleaned = cv2.morphologyEx(thresholded, cv2.MORPH_CLOSE, ellipse)

    return Image.fromarray(cleaned)
65
+
66
+
67
def predict_captcha(image, ground_truth=""):
    """
    Predict CAPTCHA text from an image, with preprocessing.

    Args:
        image: PIL Image, numpy array, or None (gradio fires the change
            event with None when the upload widget is cleared).
        ground_truth: Optional ground-truth text for comparison
            (case sensitive).

    Returns:
        Tuple of (markdown result string, preprocessed PIL Image or None)
    """
    # Clearing the image widget triggers the change event with None;
    # show the idle message instead of an error.
    if image is None:
        return "Upload an image to get started", None

    try:
        # Normalize input type: the widget may hand us a numpy array
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Resize to the model's expected 160x60 (PIL size is (width, height))
        if image.size != (160, 60):
            image = image.resize((160, 60), Image.LANCZOS)

        # Grayscale + Otsu threshold + morphological denoise
        processed_image = preprocess_image(image)

        # To normalized tensor with a batch dimension, on the model's device
        image_tensor = transform(processed_image).unsqueeze(0).to(DEVICE)

        # Greedy CTC decode: model.predict returns per-position indices
        with torch.no_grad():
            pred_indices = model.predict(image_tensor)[0]

        # Map indices to characters; indices >= len(CHARACTERS) are
        # padding/blank tokens and are dropped.
        predicted_text = ''.join(
            CHARACTERS[idx.item()] for idx in pred_indices
            if idx.item() < len(CHARACTERS)
        )

        # Build the markdown result shown in the UI
        result = "### 🎯 Prediction Result\n\n"
        result += f"# **{predicted_text}**\n\n"
        result += f"*Length: {len(predicted_text)} characters*\n\n"

        # Optional accuracy check against the user-supplied answer
        # (case sensitive by design)
        if ground_truth.strip():
            is_correct = predicted_text == ground_truth
            result += f"**Expected:** {ground_truth}\n\n"
            if is_correct:
                result += "## ✅ **CORRECT!**"
            else:
                result += "## ❌ **INCORRECT**"

        return result, processed_image

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app
        return f"❌ **Error:** {str(e)}", None
121
+
122
+
123
def extract_from_filename(filename):
    """
    Extract the ground-truth text from a CAPTCHA filename.

    Filenames are expected in the format ``TEXT_INDEX.png``; the part of
    the stem before the first underscore is the label.

    Args:
        filename: A path string, an object with a ``.name`` attribute
            (e.g. a gradio/tempfile upload), or None/empty.

    Returns:
        The extracted text, or "" when no filename is available.
    """
    if not filename:
        return ""
    # Uploaded file objects expose their path via .name; otherwise assume
    # we were handed the path itself (generalizes the original, which
    # silently returned "" for plain strings).
    path = filename.name if hasattr(filename, 'name') else filename
    # .stem drops directory and extension; the label precedes the first '_'
    return Path(path).stem.split('_')[0]
130
+
131
+
132
# Create Gradio interface.
# Layout: header, an input/result row, an info row (preprocessing steps,
# charset, processed-image preview), a collapsible architecture accordion,
# event wiring, and a footer.
with gr.Blocks(title="🔐 CAPTCHA Breaker", theme=gr.themes.Soft()) as demo:
    # Page header
    gr.Markdown("""
    <div style="text-align: center; padding: 20px;">

    # 🔐 CAPTCHA Breaker

    ### Advanced AI-Powered CAPTCHA Recognition

    Powered by **CNN + LSTM + Self-Attention** neural network

    </div>
    """)

    with gr.Row():
        # Left column: image upload + optional expected answer
        with gr.Column(scale=2):
            gr.Markdown("#### 📸 Upload Your CAPTCHA")
            # image_mode="L" makes gradio deliver a grayscale PIL image
            image_input = gr.Image(
                type="pil",
                label="Drop CAPTCHA image here",
                image_mode="L"
            )

            with gr.Row():
                ground_truth_input = gr.Textbox(
                    label="Expected Answer (optional)",
                    placeholder="Type here to verify accuracy",
                    lines=1,
                    scale=3
                )
                predict_button = gr.Button(
                    "🔍 Decode",
                    variant="primary",
                    scale=1
                )

        # Right column: markdown result panel
        with gr.Column(scale=2):
            gr.Markdown("#### 🎯 Results")
            output = gr.Markdown(
                "<div style='text-align: center; padding: 40px; color: #888;'>Upload an image to get started</div>"
            )

    # Info row: pipeline summary, character set, and preprocessed preview
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### 🔬 Preprocessing Steps Applied:")
            gr.Markdown("""
            - ✓ Auto-resize to 60×160 (if needed)
            - ✓ Grayscale conversion
            - ✓ Otsu's thresholding
            - ✓ Morphological closing (denoising)
            - ✓ Tensor normalization
            - ✓ Variable length support (3-7 chars)
            - ✓ Lowercase + Uppercase + Digits
            """)

        with gr.Column():
            gr.Markdown("#### 📊 Character Set:")
            gr.Markdown("""
            - **Digits:** 0-9
            - **Lowercase:** a-z
            - **Uppercase:** A-Z
            - **Total:** 62 characters
            """)

        with gr.Column():
            gr.Markdown("#### 🖼️ Processed Image:")
            # Shows the binarized/denoised image actually fed to the model
            preprocessed_image = gr.Image(
                label="Input After Preprocessing",
                type="pil"
            )

    # Info section (collapsed by default)
    with gr.Accordion("ℹ️ Model Architecture & Performance", open=False):
        gr.Markdown("""
        ### 🏗️ Architecture

        ```
        Input Image (1, 60, 160) [Auto-resized if needed]

        CNN: 4 Convolutional Blocks
        • Progressive feature extraction
        • 1→32→64→128→256 channels

        Bidirectional LSTM: 2 layers
        • 256 hidden units each direction
        • Learns sequential dependencies

        Self-Attention: 4 heads
        • Refines character representations
        • Improves focus on important features

        CTC Loss: Automatic Alignment
        • No bounding boxes needed!
        • Learns character positions automatically

        Output: Variable-length prediction (3-7 characters)
        ```

        ### 📈 Model Capabilities (v3)

        | Feature | Details |
        |---------|---------|
        | **Model Version** | v3 (Latest) |
        | **Text Length** | 3-7 characters (variable) |
        | **Character Set** | 0-9, a-z, A-Z (62 total) |
        | **Architecture** | CNN + LSTM + Attention |
        | **Training Data** | 10,000 synthetic CAPTCHAs |
        | **Image Resize** | Automatic (any size → 60×160) |

        ### ⚠️ Known Limitations

        - 0 vs O confusion (visual similarity)
        - i vs l vs 1 confusion (very similar shapes)
        - Limited performance on decorative/stylized fonts
        - Sensitive to extreme image distortions
        """)

    # Connect buttons to prediction function
    predict_button.click(
        fn=predict_captcha,
        inputs=[image_input, ground_truth_input],
        outputs=[output, preprocessed_image]
    )

    # Auto-predict on image upload (no ground truth supplied in this path;
    # also fires with None when the image is cleared)
    image_input.change(
        fn=lambda img: predict_captcha(img, ""),
        inputs=image_input,
        outputs=[output, preprocessed_image]
    )

    # Footer
    gr.Markdown("""
    ---
    <div style="text-align: center; color: #999; padding: 20px;">
    Built with PyTorch | Device: {device} | GitHub: vedchamp07/captcha-breaker
    </div>
    """.format(device=DEVICE))


if __name__ == "__main__":
    # share=True requests a public gradio tunnel (a no-op on HF Spaces)
    demo.launch(share=True)
models/captcha_model_v3.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e724f2d10b44f23f6794de5aa316b809388006f66eb39059851b6cd750e6de4
3
+ size 20361923
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ captcha
4
+ Pillow
5
+ numpy
6
+ matplotlib
7
+ tqdm
8
+ opencv-python # For preprocessing (grayscale, noise removal)
9
+ gradio # For interactive web app
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (145 Bytes). View file
 
src/__pycache__/model.cpython-314.pyc ADDED
Binary file (10.6 kB). View file
 
src/model.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CTC-based CAPTCHA recognition model.
3
+ Uses CNN + LSTM + CTC loss - no bounding boxes needed!
4
+
5
+ This approach is standard for sequence recognition tasks where
6
+ character positions are unknown or variable.
7
+ """
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
class CTCCaptchaModel(nn.Module):
    """
    CAPTCHA recognition using CTC (Connectionist Temporal Classification).

    Architecture:
    1. CNN backbone extracts visual features
    2. Reshape to sequence (treating width as time steps)
    3. Bidirectional LSTM processes sequence
    4. (Optional) multi-head self-attention refines the sequence
    5. Linear layer outputs character probabilities for each time step
    6. CTC loss handles alignment between predictions and ground truth

    No need for bounding boxes - CTC figures out alignment automatically!
    """

    # Longest CAPTCHA the decoder emits (the app advertises 3-7 chars).
    MAX_TEXT_LEN = 7

    def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
            hidden_size: Hidden size for LSTM layers
            num_lstm_layers: Number of LSTM layers
            use_attention: If True, add 4-head self-attention with a
                residual connection and LayerNorm on top of the LSTM.
        """
        super(CTCCaptchaModel, self).__init__()

        self.num_classes = num_classes
        # CTC needs a blank token for alignment (class index = num_classes)
        self.blank_idx = num_classes

        # CNN backbone for feature extraction
        # Input: (batch, 1, 60, 160) - grayscale image
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (32, 30, 80)

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (64, 15, 40)

            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (128, 15, 20)

            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 10)
        )

        # After CNN: (batch, 256, 15, 10)
        # We'll reshape to (batch, 10, 256*15), treating width as the
        # sequence dimension: sequence length = 10, feature dim = 3840
        self.feature_size = 256 * 15  # channels * height
        self.sequence_length = 10  # width after pooling

        # Map CNN features to LSTM input size
        self.map_to_seq = nn.Linear(self.feature_size, hidden_size)

        # Bidirectional LSTM to process the sequence
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            dropout=0.3 if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Optional self-attention on top of LSTM outputs
        self.use_attention = use_attention
        if self.use_attention:
            self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True)
            self.attn_norm = nn.LayerNorm(hidden_size * 2)
            self.attn_dropout = nn.Dropout(0.1)
        else:
            self.attn = None

        # Output layer: map LSTM outputs to character probabilities.
        # +1 for the CTC blank token; *2 because the LSTM is bidirectional.
        self.fc = nn.Linear(hidden_size * 2, num_classes + 1)

    def forward(self, x):
        """
        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities for CTC loss
            (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)

        # Extract CNN features
        features = self.cnn(x)  # (batch, 256, 15, 10)

        # Reshape to a sequence: move width in front, then flatten
        # channels*height into the feature dimension
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 256, 15)
        features = features.reshape(batch_size, self.sequence_length, self.feature_size)

        # Map to LSTM input size
        features = self.map_to_seq(features)  # (batch, 10, hidden_size)

        # Process with LSTM
        lstm_out, _ = self.lstm(features)  # (batch, 10, hidden_size*2)

        # Optional attention with residual connection + LayerNorm
        if self.attn is not None:
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out))

        # Character predictions for each time step
        logits = self.fc(lstm_out)  # (batch, 10, num_classes+1)

        # CTC expects (sequence_length, batch, num_classes+1)
        logits = logits.permute(1, 0, 2)

        # log_softmax over classes, as required by nn.CTCLoss
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)

        return log_probs

    def predict(self, x):
        """
        Decode predictions using greedy (best-path) decoding.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Predicted character indices (batch_size, MAX_TEXT_LEN),
            padded with the CTC blank index. Callers should drop
            indices >= num_classes when converting to text.
        """
        self.eval()
        with torch.no_grad():
            log_probs = self.forward(x)  # (seq_len, batch, num_classes+1)

            # Greedy decoding: take argmax at each time step
            _, preds = log_probs.max(2)  # (seq_len, batch)
            preds = preds.transpose(0, 1)  # (batch, seq_len)

            # CTC collapse: drop blanks and merge repeated characters
            decoded = []
            for pred_seq in preds:
                decoded_seq = []
                prev_char = None

                for char_idx in pred_seq:
                    char_idx = char_idx.item()

                    # Blank resets the repeat tracker so a genuine
                    # double letter separated by a blank survives
                    if char_idx == self.blank_idx:
                        prev_char = None
                        continue

                    # Skip repeated characters (CTC rule)
                    if char_idx != prev_char:
                        decoded_seq.append(char_idx)
                        prev_char = char_idx

                decoded.append(decoded_seq)

            # Pad/truncate so the batch stacks into one tensor.
            # Padding uses the blank index, NOT 0: 0 is a real character
            # class and padding with it would decode as spurious trailing
            # characters. MAX_TEXT_LEN=7 matches the supported 3-7 range
            # (the old cap of 5 truncated longer CAPTCHAs).
            max_len = self.MAX_TEXT_LEN
            padded = []
            for seq in decoded:
                if len(seq) < max_len:
                    seq = seq + [self.blank_idx] * (max_len - len(seq))
                else:
                    seq = seq[:max_len]  # Truncate if too long
                padded.append(seq)

            # Return tensor on the same device as the input
            return torch.tensor(padded, dtype=torch.long, device=x.device)
190
+
191
+
192
class CTCCaptchaModelSimple(nn.Module):
    """
    Simpler CTC model without LSTM (faster training, less memory).
    Good baseline to start with.
    """

    def __init__(self, num_classes=36):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
        """
        super(CTCCaptchaModelSimple, self).__init__()

        self.num_classes = num_classes
        # CTC blank token (class index = num_classes)
        self.blank_idx = num_classes

        # CNN backbone
        # Input: (batch, 1, 60, 160) - grayscale image
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (64, 30, 80)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (128, 15, 40)

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (256, 15, 20)

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (512, 15, 10)
        )

        # Direct mapping to character predictions,
        # treating the width dimension as the sequence
        self.classifier = nn.Sequential(
            nn.Linear(512 * 15, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes + 1)  # +1 for the CTC blank
        )

        # Width after pooling = number of time steps
        self.sequence_length = 10

    def forward(self, x):
        """
        Forward pass for CTC.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities (sequence_length, batch, num_classes+1)
        """
        batch_size = x.size(0)

        # Extract features
        features = self.features(x)  # (batch, 512, 15, 10)

        # Reshape: treat width as the sequence dimension
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 512, 15)
        features = features.reshape(batch_size, self.sequence_length, -1)

        # Classify each time step
        logits = self.classifier(features)  # (batch, 10, num_classes+1)

        # CTC format: (seq_len, batch, num_classes+1) + log_softmax
        logits = logits.permute(1, 0, 2)
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)

        return log_probs

    def predict(self, x):
        """
        Greedy (best-path) decoding.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Predicted character indices (batch_size, 5), padded with the
            CTC blank index. Callers should drop indices >= num_classes.
        """
        self.eval()
        with torch.no_grad():
            log_probs = self.forward(x)
            _, preds = log_probs.max(2)
            preds = preds.transpose(0, 1)  # (batch, seq_len)

            # CTC collapse: drop blanks, merge repeats
            decoded = []
            for pred_seq in preds:
                decoded_seq = []
                prev_char = None

                for char_idx in pred_seq:
                    char_idx = char_idx.item()
                    # Blank resets the repeat tracker
                    if char_idx == self.blank_idx:
                        prev_char = None
                        continue
                    if char_idx != prev_char:
                        decoded_seq.append(char_idx)
                        prev_char = char_idx

                decoded.append(decoded_seq)

            # Pad/truncate to length 5. Padding uses the blank index,
            # NOT 0: 0 is a real character class and padding with it
            # would decode as spurious trailing characters.
            max_len = 5
            padded = []
            for seq in decoded:
                if len(seq) < max_len:
                    seq = seq + [self.blank_idx] * (max_len - len(seq))
                else:
                    seq = seq[:max_len]
                padded.append(seq)

            # Return tensor on the same device as the input
            return torch.tensor(padded, dtype=torch.long, device=x.device)