File size: 5,440 Bytes
5617612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# ==========================================
# image_processing_gpu.py - Version GPU avec TrOCR
# ==========================================

"""
Module de traitement d'images GPU-optimisé pour calculs mathématiques
Utilise TrOCR pour une précision maximale sur GPU
"""

import time
import torch
from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset, 
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    decode_image_from_dataset,
    validate_ocr_result
)

# Variables globales pour OCR TrOCR
processor = None
model = None
OCR_MODEL_NAME = "TrOCR-base-handwritten"

def init_ocr_model() -> bool:
    """Initialise TrOCR (optimisé GPU)"""
    global processor, model
    
    try:
        print("🔄 Chargement TrOCR (GPU optimisé)...")
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel
        
        processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
        model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
        
        # Optimisations GPU
        model.eval()
        
        if torch.cuda.is_available():
            model = model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")
        else:
            device_info = "CPU (pas de GPU détecté)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")
        
        return True
        
    except Exception as e:
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False

def get_ocr_model_info() -> dict:
    """Retourne les informations du modèle OCR utilisé"""
    device = "GPU" if torch.cuda.is_available() and model is not None else "CPU"
    gpu_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "N/A"
    
    return {
        "model_name": OCR_MODEL_NAME,
        "device": device,
        "gpu_name": gpu_name,
        "framework": "HuggingFace-Transformers",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten"
    }

def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]:
    """
    OCR avec TrOCR (GPU optimisé)
    
    Args:
        image_dict: Image d'entrée (format Gradio)
        debug: Afficher les logs de debug
        
    Returns:
        (résultat_ocr, image_optimisée, données_dataset)
    """
    if image_dict is None or processor is None or model is None:
        if debug:
            print("  ❌ Image manquante ou TrOCR non initialisé")
        return "0", None, None

    try:
        start_time = time.time()
        if debug:
            print("  🔄 Début OCR TrOCR...")
        
        # Optimiser image (fonction commune)
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)  # TrOCR préfère 384x384
        if optimized_image is None:
            if debug:
                print("  ❌ Échec optimisation image")
            return "0", None, None

        # TrOCR - traitement spécialisé GPU
        if debug:
            print("  🤖 Lancement TrOCR...")
        
        with torch.no_grad():
            # Preprocessing
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values
            
            # GPU transfer si disponible
            if torch.cuda.is_available():
                pixel_values = pixel_values.cuda()
            
            # Génération optimisée
            generated_ids = model.generate(
                pixel_values,
                max_length=4,           # Optimisé pour les calculs
                num_beams=1,            # Rapide
                do_sample=False,        # Déterministe
                early_stopping=True,    # Arrêt rapide
                pad_token_id=processor.tokenizer.pad_token_id
            )
            
            # Décodage
            result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            final_result = validate_ocr_result(result, max_length=4)
        
        # Préparer pour dataset (fonction commune)
        dataset_image_data = prepare_image_for_dataset(optimized_image)
        
        if debug:
            total_time = time.time() - start_time
            device = "GPU" if torch.cuda.is_available() else "CPU"
            print(f"  ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")
        
        return final_result, optimized_image, dataset_image_data

    except Exception as e:
        print(f"❌ Erreur OCR TrOCR: {e}")
        return "0", None, None

def recognize_number_fast(image_dict) -> tuple[str, any]:
    """Version rapide standard"""
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image

def recognize_number(image_dict) -> str:
    """Interface standard"""
    result, _ = recognize_number_fast(image_dict)
    return result

# Fonctions spécifiques au fine-tuning (pour plus tard)
def prepare_for_finetuning(dataset_path: str) -> dict:
    """Prépare le dataset pour le fine-tuning TrOCR"""
    # TODO: Implémenter quand on aura HF Pro
    return {"status": "ready_for_implementation"}

def quantize_model() -> bool:
    """Quantize le modèle TrOCR pour optimiser les performances CPU"""
    # TODO: Implémenter la quantization
    return False