Spaces:
Sleeping
Sleeping
File size: 5,440 Bytes
5617612 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# ==========================================
# image_processing_gpu.py - Version GPU avec TrOCR
# ==========================================
"""
Module de traitement d'images GPU-optimisé pour calculs mathématiques
Utilise TrOCR pour une précision maximale sur GPU
"""
import time
from typing import Any

import torch

from utils import (
optimize_image_for_ocr,
prepare_image_for_dataset,
create_thumbnail_fast,
create_white_canvas,
log_memory_usage,
cleanup_memory,
decode_image_from_dataset,
validate_ocr_result
)
# Module-level OCR state, populated by init_ocr_model().
processor = None  # TrOCRProcessor once loaded; None until init_ocr_model() succeeds
model = None  # VisionEncoderDecoderModel once loaded; None until init_ocr_model() succeeds
# Human-readable label reported by get_ocr_model_info().
OCR_MODEL_NAME = "TrOCR-base-handwritten"
def init_ocr_model() -> bool:
    """Load the TrOCR processor/model into the module globals.

    Places the model on GPU when CUDA is available and switches it to
    eval mode for inference.

    Returns:
        True when loading succeeded, False on any failure.
    """
    global processor, model
    try:
        print("🔄 Chargement TrOCR (GPU optimisé)...")
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        # Same HuggingFace checkpoint for both processor and model.
        checkpoint = 'microsoft/trocr-base-handwritten'
        processor = TrOCRProcessor.from_pretrained(checkpoint)
        model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        # Inference only: disable dropout/training-mode behaviour.
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")
        else:
            device_info = "CPU (pas de GPU détecté)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")
        return True
    except Exception as e:
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False
def get_ocr_model_info() -> dict:
    """Describe the OCR backend currently in use (name, device, version)."""
    cuda_ok = torch.cuda.is_available()
    return {
        "model_name": OCR_MODEL_NAME,
        # Report GPU only when CUDA exists AND the model was actually loaded.
        "device": "GPU" if cuda_ok and model is not None else "CPU",
        "gpu_name": torch.cuda.get_device_name() if cuda_ok else "N/A",
        "framework": "HuggingFace-Transformers",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten",
    }
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]:
"""
OCR avec TrOCR (GPU optimisé)
Args:
image_dict: Image d'entrée (format Gradio)
debug: Afficher les logs de debug
Returns:
(résultat_ocr, image_optimisée, données_dataset)
"""
if image_dict is None or processor is None or model is None:
if debug:
print(" ❌ Image manquante ou TrOCR non initialisé")
return "0", None, None
try:
start_time = time.time()
if debug:
print(" 🔄 Début OCR TrOCR...")
# Optimiser image (fonction commune)
optimized_image = optimize_image_for_ocr(image_dict, max_size=384) # TrOCR préfère 384x384
if optimized_image is None:
if debug:
print(" ❌ Échec optimisation image")
return "0", None, None
# TrOCR - traitement spécialisé GPU
if debug:
print(" 🤖 Lancement TrOCR...")
with torch.no_grad():
# Preprocessing
pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values
# GPU transfer si disponible
if torch.cuda.is_available():
pixel_values = pixel_values.cuda()
# Génération optimisée
generated_ids = model.generate(
pixel_values,
max_length=4, # Optimisé pour les calculs
num_beams=1, # Rapide
do_sample=False, # Déterministe
early_stopping=True, # Arrêt rapide
pad_token_id=processor.tokenizer.pad_token_id
)
# Décodage
result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
final_result = validate_ocr_result(result, max_length=4)
# Préparer pour dataset (fonction commune)
dataset_image_data = prepare_image_for_dataset(optimized_image)
if debug:
total_time = time.time() - start_time
device = "GPU" if torch.cuda.is_available() else "CPU"
print(f" ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")
return final_result, optimized_image, dataset_image_data
except Exception as e:
print(f"❌ Erreur OCR TrOCR: {e}")
return "0", None, None
def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Convenience wrapper: run OCR and return (result, optimized_image).

    Drops the dataset payload produced by the full pipeline.
    """
    # Fix: annotation uses typing.Any — the original used the builtin `any`.
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Minimal OCR interface: return only the recognised text."""
    ocr_text, _unused_image = recognize_number_fast(image_dict)
    return ocr_text
# Fine-tuning-specific functions (for later)
def prepare_for_finetuning(dataset_path: str) -> dict:
    """Placeholder: prepare the dataset at *dataset_path* for TrOCR fine-tuning.

    Not implemented yet; always reports readiness.
    """
    # TODO: Implement once HF Pro is available.
    return {"status": "ready_for_implementation"}
def quantize_model() -> bool:
    """Placeholder: quantize the TrOCR model for faster CPU inference.

    Quantization is not implemented yet, so this always returns False.
    (Also removes a stray trailing `|` artifact that followed the
    original return statement.)
    """
    # TODO: Implement quantization (e.g. torch dynamic quantization).
    return False