# CalcTrainer / image_processing_gpu.py
# (Hugging Face file-page chrome captured with the export, kept as a comment
#  so the module remains valid Python:)
# hoololi's picture — Upload 5 files — 5617612 verified — raw — history blame — 5.44 kB
# ==========================================
# image_processing_gpu.py - Version GPU avec TrOCR
# ==========================================
"""
Module de traitement d'images GPU-optimisé pour calculs mathématiques
Utilise TrOCR pour une précision maximale sur GPU
"""
import time
from typing import Any

import torch

from utils import (
    optimize_image_for_ocr,
    prepare_image_for_dataset,
    create_thumbnail_fast,
    create_white_canvas,
    log_memory_usage,
    cleanup_memory,
    decode_image_from_dataset,
    validate_ocr_result
)
# Global OCR state — populated by init_ocr_model(); both stay None until then.
processor = None  # TrOCRProcessor (image preprocessing + tokenizer)
model = None  # VisionEncoderDecoderModel (the TrOCR network)
OCR_MODEL_NAME = "TrOCR-base-handwritten"
def init_ocr_model() -> bool:
    """Load the TrOCR processor and model into the module globals.

    Moves the model to GPU when CUDA is available and switches it to eval
    mode. Returns True on success, False if anything went wrong (the error
    is printed, not raised).
    """
    global processor, model
    try:
        print("🔄 Chargement TrOCR (GPU optimisé)...")
        # Imported lazily so the module can be loaded without transformers.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
        model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
        model.eval()  # inference only — disable dropout etc.

        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            model = model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")
        else:
            device_info = "CPU (pas de GPU détecté)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")
        return True
    except Exception as e:
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False
def get_ocr_model_info() -> dict:
    """Describe the OCR setup: model name, active device, and framework."""
    cuda_ok = torch.cuda.is_available()
    # "GPU" only when CUDA exists AND the model has actually been loaded.
    running_on_gpu = cuda_ok and model is not None
    return {
        "model_name": OCR_MODEL_NAME,
        "device": "GPU" if running_on_gpu else "CPU",
        "gpu_name": torch.cuda.get_device_name() if cuda_ok else "N/A",
        "framework": "HuggingFace-Transformers",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten",
    }
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run TrOCR on a drawn image and return the recognized text.

    Args:
        image_dict: Input image (Gradio dict format), or None.
        debug: When True, print step-by-step progress and timing logs.

    Returns:
        Tuple of (ocr_result, optimized_image, dataset_image_data).
        Returns ("0", None, None) when the input is missing, the model is
        not initialized, preprocessing fails, or OCR raises.
    """
    # Guard: nothing to do without an image and a loaded model/processor.
    if image_dict is None or processor is None or model is None:
        if debug:
            print(" ❌ Image manquante ou TrOCR non initialisé")
        return "0", None, None
    try:
        start_time = time.time()
        if debug:
            print(" 🔄 Début OCR TrOCR...")
        # Shared preprocessing helper; TrOCR prefers 384x384 inputs.
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print(" ❌ Échec optimisation image")
            return "0", None, None
        if debug:
            print(" 🤖 Lancement TrOCR...")
        with torch.no_grad():  # inference only — no gradient tracking
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values
            # Move inputs to GPU when available (model lives there too).
            if torch.cuda.is_available():
                pixel_values = pixel_values.cuda()
            # Greedy, deterministic decoding; short max_length because we
            # only expect small numeric answers.
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id,
            )
            result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        # Sanitize raw OCR output (shared helper).
        final_result = validate_ocr_result(result, max_length=4)
        # Encode image for dataset storage (shared helper).
        dataset_image_data = prepare_image_for_dataset(optimized_image)
        if debug:
            total_time = time.time() - start_time
            device = "GPU" if torch.cuda.is_available() else "CPU"
            print(f" ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")
        return final_result, optimized_image, dataset_image_data
    except Exception as e:
        print(f"❌ Erreur OCR TrOCR: {e}")
        return "0", None, None
def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Run OCR and return (result, optimized_image), dropping the dataset payload.

    Note: annotation fixed from builtin ``any`` to ``typing.Any``.
    """
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Standard OCR interface: return only the recognized text."""
    text, _unused_image = recognize_number_fast(image_dict)
    return text
# Fine-tuning specific functions (for later)
def prepare_for_finetuning(dataset_path: str) -> dict:
    """Placeholder: prepare a dataset for TrOCR fine-tuning (not implemented).

    Always reports readiness; real logic is deferred until HF Pro is available.
    """
    # TODO: implement once HF Pro is available
    return {"status": "ready_for_implementation"}
def quantize_model() -> bool:
    """Placeholder: quantize TrOCR for faster CPU inference (not implemented).

    Always returns False until quantization is actually wired up.
    """
    # TODO: implement quantization (e.g. torch dynamic quantization)
    return False