File size: 7,039 Bytes
04e423f
 
 
 
 
 
 
 
 
322be7d
04e423f
 
 
 
 
 
 
 
 
 
 
858aaf1
04e423f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322be7d
 
 
 
 
 
 
 
04e423f
322be7d
 
04e423f
 
 
 
 
858aaf1
04e423f
 
 
 
 
858aaf1
04e423f
 
858aaf1
04e423f
 
322be7d
04e423f
 
 
 
 
 
 
322be7d
 
04e423f
 
 
858aaf1
04e423f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858aaf1
04e423f
 
 
858aaf1
04e423f
 
 
 
858aaf1
04e423f
 
 
 
 
 
 
 
 
 
 
 
 
322be7d
04e423f
 
858aaf1
04e423f
858aaf1
04e423f
858aaf1
04e423f
 
858aaf1
04e423f
 
 
 
 
 
 
 
 
 
 
858aaf1
04e423f
 
858aaf1
04e423f
 
858aaf1
 
04e423f
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import torch
import cv2
import os
import random
import numpy as np
from src.config import cfg
from src.model_crnn import CRNN
from src.vocab import ctc_greedy_decode, vocab_size
from src.plotting import TrainingMetrics
from src.generateCaptcha import generate_captcha

def load_model(checkpoint_path="checkpoints/best_model.pth", hidden=320, dropout=0.05):
    """Load the trained CRNN model from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint dict containing at least the keys
            ``'epoch'``, ``'best_val_loss'`` and ``'model_state_dict'``.
        hidden: Hidden size passed to ``CRNN``. Presumably must match the value
            used when the checkpoint was trained, otherwise loading the state
            dict is expected to fail (TODO confirm against training script).
        dropout: Dropout rate passed to ``CRNN``.

    Returns:
        The model in eval mode, placed on CUDA when available, otherwise CPU.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location lets a checkpoint saved on GPU be opened on a CPU-only host.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"Loaded model from epoch {checkpoint['epoch']}")
    print(f"   Best validation loss: {checkpoint['best_val_loss']:.4f}")
    print(f"   Loading to device: {device}")

    # Create model and load weights
    model = CRNN(vocab_size=vocab_size(), hidden=hidden, dropout=dropout)
    model.load_state_dict(checkpoint['model_state_dict'])

    # load_state_dict copies values into the freshly built (CPU) parameters, so
    # the model must be moved explicitly to honour the device reported above.
    model = model.to(device)
    model.eval()

    return model

def preprocess_image(image_path, target_size=(cfg.W_max, cfg.H)):
    """Read an image file and convert it into a single-item batch tensor.

    The image is read as grayscale or colour depending on ``cfg.grayscale``.
    It is resized to ``target_size``, given as ``(width, height)``, which is
    the order ``cv2.resize`` expects. Pixel values are scaled from 0-255 into
    [0, 1]. The original author notes this mirrors the training pipeline.

    NOTE(review): in colour mode OpenCV yields BGR channel order; confirm the
    training data was loaded the same way.

    Returns:
        A float tensor shaped ``[1, 1, H, W]`` (grayscale) or ``[1, 3, H, W]``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    read_flag = cv2.IMREAD_GRAYSCALE if cfg.grayscale else cv2.IMREAD_COLOR
    pixels = cv2.imread(image_path, read_flag)
    if pixels is None:
        raise ValueError(f"Failed to load image: {image_path}")

    resized = cv2.resize(pixels, target_size)
    scaled = torch.from_numpy(resized).float() / 255.0

    if not cfg.grayscale:
        # HWC -> CHW, then prepend the batch axis.
        return scaled.permute(2, 0, 1).unsqueeze(0)
    # HW -> 1xHxW (channel axis) -> 1x1xHxW (batch axis).
    return scaled.unsqueeze(0).unsqueeze(0)

def predict_captcha(model, image_tensor, device):
    """Decode the CAPTCHA text for one preprocessed image batch.

    Args:
        model: Trained network whose output is passed to ``ctc_greedy_decode``.
        image_tensor: Batch tensor (e.g. from ``preprocess_image``).
        device: Device the input is moved to before the forward pass.

    Returns:
        The first decoded string, or ``""`` when the decoder yields nothing.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        decoded = ctc_greedy_decode(logits)
        if not decoded:
            return ""
        return decoded[0]

def generate_test_captcha(text, filename, width=256, height=60):
    """Render ``text`` as a CAPTCHA image and save it under ``cfg.RESULT_DIR``.

    Args:
        text: The characters to draw.
        filename: File name (not path) for the saved image.
        width: Image width passed to ``generate_captcha``.
        height: Image height passed to ``generate_captcha``.

    Returns:
        The full path of the written image file.
    """
    captcha_img = generate_captcha(text, width=width, height=height)

    # The output folder may not exist yet on a fresh checkout.
    os.makedirs(cfg.RESULT_DIR, exist_ok=True)
    out_path = os.path.join(cfg.RESULT_DIR, filename)

    captcha_img.save(out_path)
    print(f"Generated enhanced test CAPTCHA: {filename}")
    return out_path

def _random_captcha_text():
    """Return random text drawn from ``cfg.chars`` with a length in the configured range."""
    # randint is drawn before choices, matching the original inline expression,
    # so a seeded RNG still produces the same texts.
    length = random.randint(cfg.CAPTCHA_LEN_LOWER_LIMIT, cfg.CAPTCHA_LEN_UPPER_LIMIT)
    return ''.join(random.choices(cfg.chars, k=length))


def _character_accuracy(results):
    """Count position-wise character matches.

    Args:
        results: Iterable of ``(target_text, image_path, prediction)`` tuples.

    Returns:
        ``(correct_chars, total_chars)``. ``total_chars`` is the summed target
        length; positions past the end of the shorter string never match.
    """
    correct_chars = 0
    total_chars = 0
    for target_text, _, prediction in results:
        total_chars += len(target_text)
        correct_chars += sum(t == p for t, p in zip(target_text, prediction))
    return correct_chars, total_chars


def main(num_samples=4):
    """Generate random CAPTCHAs, run the trained model on them and report accuracy.

    Steps:
    1. Load ``checkpoints/best_model.pth`` via ``load_model``.
    2. Render ``num_samples`` random CAPTCHAs into ``cfg.RESULT_DIR``.
    3. Print a per-image table, exact-match accuracy and position-wise
       character accuracy.
    4. Try to plot the results with ``TrainingMetrics.plot_results``.

    Args:
        num_samples: Number of test CAPTCHAs to generate (default 4, the value
            that was previously hard-coded).

    Errors are reported on stdout instead of being raised. The hint about the
    checkpoint location is printed only when loading the model fails.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(cfg.RESULT_DIR, exist_ok=True)

    # Load the model in its own try block so that unrelated later failures
    # (generation, inference, plotting) are not misreported as a missing checkpoint.
    try:
        print("Loading trained model...")
        model = load_model()
        model = model.to(device)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error: {e}")
        print("Make sure you have a trained model in checkpoints/best_model.pth")
        return

    try:
        # Generate test CAPTCHAs as (target_text, image_path) pairs.
        print("\nGenerating enhanced test CAPTCHAs...")
        test_cases = []
        for i in range(num_samples):
            text = _random_captcha_text()
            image_path = generate_test_captcha(text, f"{text}_{i}.png", width=cfg.W_max, height=cfg.H)
            test_cases.append((text, image_path))

        # Run inference
        print("\nRunning inference...")
        print("-" * 60)
        print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
        print("-" * 60)

        results = []  # (target_text, image_path, prediction)
        correct_count = 0
        for target_text, image_path in test_cases:
            prediction = ""
            try:
                image_tensor = preprocess_image(image_path)
                prediction = predict_captcha(model, image_tensor, device)

                is_correct = prediction == target_text
                if is_correct:
                    correct_count += 1

                status = "CORRECT" if is_correct else "WRONG"
                print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
            except Exception as e:
                # Best effort: a failing image is reported and counted as wrong.
                print(f"Error processing {image_path}: {e}")
            results.append((target_text, image_path, prediction))

        # Summary (guarded so num_samples=0 cannot divide by zero)
        print("-" * 60)
        total = len(results)
        accuracy = (correct_count / total) * 100 if total else 0.0
        print(f"Overall Accuracy: {correct_count}/{total} ({accuracy:.1f}%)")

        correct_chars, total_chars = _character_accuracy(results)
        char_accuracy = (correct_chars / total_chars) * 100 if total_chars > 0 else 0
        print(f"Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")

        if accuracy >= 80:
            print("Excellent performance!")
        elif accuracy >= 60:
            print("Good performance!")
        else:
            print("Room for improvement...")

        # Create and save results plot; failure here is only a warning.
        print("\nGenerating results visualization...")
        try:
            metrics = TrainingMetrics()
            image_paths = [path for _, path, _ in results]
            predictions = [pred for _, _, pred in results]
            targets = [target for target, _, _ in results]

            # NOTE(review): plot_results presumably writes into "Metrics/";
            # the directory is created up front in case it is missing.
            os.makedirs("Metrics", exist_ok=True)

            metrics.plot_results(image_paths, predictions, targets)
            print("Results plot generated successfully!")
        except Exception as e:
            print(f"Warning: Could not generate plot: {e}")

    except Exception as e:
        print(f"Error: {e}")

if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()