File size: 4,886 Bytes
1e4b8a0
36961d0
1e4b8a0
 
cdceb7f
36961d0
cdceb7f
 
1e4b8a0
36961d0
 
 
1e4b8a0
 
 
cdceb7f
1e4b8a0
 
 
 
 
 
 
cdceb7f
 
 
 
 
 
36961d0
cdceb7f
 
 
36961d0
cdceb7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0316030
cdceb7f
 
36961d0
cdceb7f
36961d0
cdceb7f
 
 
 
 
 
 
 
 
 
 
 
1e4b8a0
36961d0
cdceb7f
1e4b8a0
36961d0
1e4b8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdceb7f
1e4b8a0
 
cdceb7f
1e4b8a0
 
 
cdceb7f
1e4b8a0
 
 
 
 
 
 
 
 
cdceb7f
1e4b8a0
 
 
36961d0
cdceb7f
1e4b8a0
 
 
 
 
cdceb7f
 
1e4b8a0
cdceb7f
1e4b8a0
 
 
 
 
cdceb7f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# ==========================================
# image_processing_gpu.py - Version ZeroGPU simplifiée
# ==========================================

"""
Module de traitement d'images GPU-optimisé pour ZeroGPU HuggingFace Spaces
"""

import time
from typing import Any

import spaces
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset,
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    validate_ocr_result
)

# Module-level OCR state, populated by init_ocr_model(); both stay None until then.
processor = None  # TrOCRProcessor instance once loaded
model = None  # VisionEncoderDecoderModel instance once loaded
OCR_MODEL_NAME = "TrOCR-base-handwritten"  # display name reported by get_ocr_model_info()

def init_ocr_model() -> bool:
    """Load the TrOCR processor and model for ZeroGPU inference.

    Populates the module-level ``processor`` and ``model`` globals and
    moves the model to CUDA when a GPU has been allocated.

    Returns:
        bool: True when the model loaded successfully, False on any error.
    """
    global processor, model

    try:
        print("🔄 Chargement TrOCR (ZeroGPU)...")

        checkpoint = 'microsoft/trocr-base-handwritten'
        processor = TrOCRProcessor.from_pretrained(checkpoint)
        model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

        # Inference mode: disables dropout and other training-only behavior.
        model.eval()

        if not torch.cuda.is_available():
            device_info = "CPU (ZeroGPU pas encore alloué)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")
        else:
            model = model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")

        return True

    except Exception as e:
        # Best-effort loader: report the failure, let the caller decide.
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False

def get_ocr_model_info() -> dict:
    """Describe the configured OCR model and the current execution device.

    Returns:
        dict: Metadata with keys ``model_name``, ``device``, ``gpu_name``,
        ``framework``, ``optimized_for`` and ``version``.
    """
    cuda_ready = torch.cuda.is_available()

    return {
        "model_name": OCR_MODEL_NAME,
        "device": "ZeroGPU" if cuda_ready else "CPU",
        "gpu_name": torch.cuda.get_device_name() if cuda_ready else "N/A",
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten"
    }

@spaces.GPU
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run handwritten-number OCR with TrOCR on ZeroGPU.

    Fix: the return annotation previously used the builtin ``any`` (a
    function, not a type); it now uses ``typing.Any``.

    Args:
        image_dict: Drawing/image input from the UI layer, forwarded to
            ``optimize_image_for_ocr``; may be None.
        debug: When True, print step-by-step progress and timing.

    Returns:
        tuple: ``(text, optimized_image, dataset_image_data)``. On any
        failure (missing input, model not initialized, OCR error) the
        fallback ``("0", None, None)`` is returned instead of raising.
    """
    if image_dict is None:
        if debug:
            print("  ❌ Image manquante")
        return "0", None, None

    try:
        start_time = time.time()
        if debug:
            print("  🔄 Début OCR TrOCR ZeroGPU...")

        # Normalize/downscale the drawing before feeding the model.
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print("  ❌ Échec optimisation image")
            return "0", None, None

        # Guard: init_ocr_model() must have been called beforehand.
        if processor is None or model is None:
            if debug:
                print("  ❌ TrOCR non initialisé")
            return "0", None, None

        if debug:
            print("  🤖 Lancement TrOCR ZeroGPU...")

        with torch.no_grad():
            # Preprocess the image into model-ready pixel tensors.
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values

            # Move the input to GPU once ZeroGPU has been allocated.
            if torch.cuda.is_available():
                pixel_values = pixel_values.cuda()

            # Greedy decoding capped at 4 tokens: only a short number is expected.
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id
            )

            # Decode token ids back to text, then sanitize the result.
            result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            final_result = validate_ocr_result(result, max_length=4)

        # Serialize the optimized image for later dataset storage.
        dataset_image_data = prepare_image_for_dataset(optimized_image)

        if debug:
            total_time = time.time() - start_time
            device = "ZeroGPU" if torch.cuda.is_available() else "CPU"
            print(f"  ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")
            if dataset_image_data:
                print(f"  🖼️ Image dataset: {type(dataset_image_data.get('handwriting_image', 'None'))}")

        return final_result, optimized_image, dataset_image_data

    except Exception as e:
        # Best-effort: never crash the UI; fall back to the default answer.
        print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}")
        return "0", None, None

def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Recognize a handwritten number and return the preprocessed image too.

    Fix: the return annotation previously used the builtin ``any`` (a
    function, not a type); it now uses ``typing.Any``.

    Thin wrapper around ``recognize_number_fast_with_image`` that drops
    the dataset payload.

    Returns:
        tuple: ``(text, optimized_image)`` where ``optimized_image`` may
        be None on failure.
    """
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image

def recognize_number(image_dict) -> str:
    """Return only the recognized text for ``image_dict`` (simplest API)."""
    text, _unused_image = recognize_number_fast(image_dict)
    return text