Spaces:
Running
Running
mohakkapoor4
Refactor .gitignore to specify checkpoint file types and exclude all but the best model. Update inference.py to use enhanced CAPTCHA generation and adjust dimensions. Increase training epochs in train.py for better model performance. Update training metrics and data generation logic in data.py for improved dataset handling and augmentation. Update config.py for dataset path consistency.
322be7d
| import torch | |
| import cv2 | |
| import os | |
| import random | |
| import numpy as np | |
| from src.config import cfg | |
| from src.model_crnn import CRNN | |
| from src.vocab import ctc_greedy_decode, vocab_size | |
| from src.plotting import TrainingMetrics | |
| from src.generateCaptcha import generate_captcha | |
def load_model(checkpoint_path="checkpoints/best_model.pth"):
    """Restore the trained CRNN from a saved checkpoint and set eval mode.

    Args:
        checkpoint_path: Location of the checkpoint file written by training.

    Returns:
        The CRNN model with saved weights loaded, switched to eval mode.

    Raises:
        FileNotFoundError: If no file exists at ``checkpoint_path``.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Prefer the GPU when one is present; otherwise stay on CPU.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location remaps the stored tensors onto the device chosen above.
    state = torch.load(checkpoint_path, map_location=target)

    for line in (
        f"Loaded model from epoch {state['epoch']}",
        f" Best validation loss: {state['best_val_loss']:.4f}",
        f" Loading to device: {target}",
    ):
        print(line)

    # Rebuild the architecture with the same hyperparameters used in training,
    # then copy the saved weights in.
    net = CRNN(vocab_size=vocab_size(), hidden=320, dropout=0.05)
    net.load_state_dict(state['model_state_dict'])
    net.eval()
    return net
def preprocess_image(image_path, target_size=(cfg.W_max, cfg.H)):
    """Load an image from disk and convert it into a model-ready tensor.

    Mirrors the training-time pipeline: read (grayscale or color per
    ``cfg.grayscale``), resize to ``target_size`` (width, height), scale
    pixels into [0, 1], and add batch/channel dimensions.

    Args:
        image_path: Path to the image file.
        target_size: (width, height) passed to ``cv2.resize``.

    Returns:
        A float tensor shaped [1, 1, H, W] (grayscale) or [1, 3, H, W] (color).

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    read_flag = cv2.IMREAD_GRAYSCALE if cfg.grayscale else cv2.IMREAD_COLOR
    img = cv2.imread(image_path, read_flag)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")

    resized = cv2.resize(img, target_size)

    # Normalize pixel intensities from [0, 255] into [0, 1].
    tensor = torch.from_numpy(resized).float() / 255.0

    if cfg.grayscale:
        # [H, W] -> [1, 1, H, W]
        return tensor.unsqueeze(0).unsqueeze(0)
    # [H, W, C] -> [1, 3, H, W]
    return tensor.permute(2, 0, 1).unsqueeze(0)
def predict_captcha(model, image_tensor, device):
    """Run the model on one preprocessed image and return the decoded text.

    Args:
        model: The trained CRNN model (already on ``device``).
        image_tensor: Batched input tensor from ``preprocess_image``.
        device: Torch device to run inference on.

    Returns:
        The greedy-decoded prediction string, or "" if decoding yields nothing.
    """
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        decoded = ctc_greedy_decode(logits)
    return decoded[0] if decoded else ""
def generate_test_captcha(text, filename, width=256, height=60):
    """Render a synthetic CAPTCHA image and save it under ``cfg.RESULT_DIR``.

    Args:
        text: Ground-truth string to render into the image.
        filename: File name (not a path) for the saved image.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The full path of the saved image file.
    """
    # Delegate rendering to the project's enhanced generator.
    img = generate_captcha(text, width=width, height=height)

    # Ensure the results directory exists before saving.
    os.makedirs(cfg.RESULT_DIR, exist_ok=True)
    filepath = os.path.join(cfg.RESULT_DIR, filename)
    img.save(filepath)

    # Fix: previously printed a literal placeholder instead of the saved path.
    print(f"Generated enhanced test CAPTCHA: {filepath}")
    return filepath
def main():
    """End-to-end inference demo: generate CAPTCHAs, predict, and report accuracy."""
    # Setup: pick a device and make sure the output directory exists.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    os.makedirs(cfg.RESULT_DIR, exist_ok=True)
    try:
        # Load trained model from the default checkpoint path.
        print("Loading trained model...")
        model = load_model()
        model = model.to(device)
        print("Model loaded successfully!")
        # Generate 4 test CAPTCHAs with random text drawn from the
        # configured character set and length limits.
        print("\nGenerating enhanced test CAPTCHAs...")
        test_cases = []
        for i in range(4):
            # Generate random ground-truth text of random allowed length.
            text = ''.join(random.choices(cfg.chars, k=random.randint(cfg.CAPTCHA_LEN_LOWER_LIMIT, cfg.CAPTCHA_LEN_UPPER_LIMIT)))
            filename = f"{text}_{i}.png"
            # Render and save the image; remember its path.
            image_path = generate_test_captcha(text, filename, width=cfg.W_max, height=cfg.H)
            test_cases.append((text, image_path, ""))  # Add empty prediction slot
        # Run inference over the generated images, printing a results table.
        print("\nRunning inference...")
        print("-" * 60)
        print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
        print("-" * 60)
        correct_count = 0
        for i, (target_text, image_path, _) in enumerate(test_cases):
            try:
                # Preprocess image to a model-ready tensor.
                image_tensor = preprocess_image(image_path)
                # Run prediction.
                prediction = predict_captcha(model, image_tensor, device)
                # Store prediction back into test_cases for later plotting.
                test_cases[i] = (target_text, image_path, prediction)
                # Exact-match check against the ground truth.
                is_correct = prediction == target_text
                if is_correct:
                    correct_count += 1
                # Display one table row per image.
                status = "CORRECT" if is_correct else "WRONG"
                print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
            except Exception as e:
                # Best-effort: a failure on one image should not stop the rest.
                print(f"Error processing {image_path}: {e}")
        # Summary: whole-string accuracy over all test cases.
        print("-" * 60)
        accuracy = (correct_count / len(test_cases)) * 100
        print(f"Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")
        # Per-character accuracy: compare position by position over the
        # overlapping prefix of target and prediction.
        total_chars = 0
        correct_chars = 0
        for target_text, _, prediction in test_cases:
            total_chars += len(target_text)
            # Count correct characters (position by position).
            min_len = min(len(target_text), len(prediction))
            for i in range(min_len):
                if target_text[i] == prediction[i]:
                    correct_chars += 1
        char_accuracy = (correct_chars / total_chars) * 100 if total_chars > 0 else 0
        print(f"Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")
        # Qualitative verdict based on whole-string accuracy thresholds.
        if accuracy >= 80:
            print("Excellent performance!")
        elif accuracy >= 60:
            print("Good performance!")
        else:
            print("Room for improvement...")
        # Create and save a visualization of the predictions.
        print("\nGenerating results visualization...")
        try:
            metrics = TrainingMetrics()
            image_paths = [case[1] for case in test_cases]
            predictions = [case[2] for case in test_cases]
            targets = [case[0] for case in test_cases]
            # Create the plot output directory if it doesn't exist.
            # NOTE(review): "Metrics" is hard-coded here while images go to
            # cfg.RESULT_DIR — presumably TrainingMetrics writes there; verify.
            os.makedirs("Metrics", exist_ok=True)
            # Plot results.
            metrics.plot_results(image_paths, predictions, targets)
            print("Results plot generated successfully!")
        except Exception as e:
            # Plotting is optional; report but continue.
            print(f"Warning: Could not generate plot: {e}")
    except Exception as e:
        # Top-level boundary: surface the error and hint at the likely cause.
        print(f"Error: {e}")
        print("Make sure you have a trained model in checkpoints/best_model.pth")
if __name__ == "__main__":
    # Run the inference demo only when executed as a script, not on import.
    main()