# Source: CalcTrainer / image_processing_gpu.py
# Uploaded by hoololi ("Upload image_processing_gpu.py"), commit 1e4b8a0 (verified), 5.68 kB
# ==========================================
# image_processing_gpu.py - Version ZeroGPU compatible
# ==========================================
"""
Module de traitement d'images GPU-optimisé pour calculs mathématiques
Compatible ZeroGPU HuggingFace Spaces
"""
import time
from typing import Any
# Import `spaces` (ZeroGPU helper) with a fallback for environments where it is missing
try:
    import spaces
    print("✅ Import spaces réussi dans image_processing_gpu")
    SPACES_AVAILABLE = True
except ImportError as e:
    print(f"❌ Import spaces échoué: {e}")

    # Minimal stand-in so `@spaces.GPU` keeps working when the package is absent
    class MockSpaces:
        """No-op replacement exposing the ``GPU`` decorator used by this module."""

        @staticmethod
        def GPU(func=None, **_options):
            """Return the decorated function unchanged.

            Works both as a bare decorator (``@spaces.GPU``) and as a
            parameterized one (``@spaces.GPU(duration=60)``); keyword options
            are accepted and ignored.

            Args:
                func: The function being decorated, or None when the decorator
                    is called with keyword options only.
                **_options: Ignored decorator options.

            Returns:
                ``func`` itself in the bare form, otherwise a decorator that
                returns its argument unchanged.
            """
            def _passthrough(target):
                print(f"MockSpaces.GPU décorateur appliqué à {target.__name__}")
                return target

            # Bare usage hands us the function directly; parameterized usage
            # must hand back a decorator instead.
            return _passthrough if func is None else _passthrough(func)

    spaces = MockSpaces()
    SPACES_AVAILABLE = False
# Torch is optional at import time; TORCH_AVAILABLE records whether it loaded
try:
    import torch
except ImportError:
    print("❌ Torch non disponible")
    # Keep the name bound so later references never raise NameError;
    # callers are expected to check TORCH_AVAILABLE before using it.
    torch = None
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True
from utils import (
optimize_image_for_ocr,
prepare_image_for_dataset,
create_thumbnail_fast,
create_white_canvas,
log_memory_usage,
cleanup_memory,
decode_image_from_dataset,
validate_ocr_result
)
# Module-level OCR state: display name of the model, plus the processor/model
# pair, which stays None until a model has been loaded
OCR_MODEL_NAME = "TrOCR-base-handwritten"
processor = model = None
def init_ocr_model() -> bool:
    """Load the TrOCR processor and model into the module-level globals.

    The model is switched to eval mode and moved to CUDA when
    ``torch.cuda.is_available()`` is true; otherwise it stays on CPU.
    ``processor`` and ``model`` are assigned together, and only after every
    step succeeded, so a failure never leaves one set while the other is
    still ``None`` (and never clobbers a previously loaded pair).

    Returns:
        True when both objects are loaded, False when torch is unavailable or
        loading failed. Errors are printed, never raised.
    """
    global processor, model
    checkpoint = 'microsoft/trocr-base-handwritten'
    try:
        print("🔄 Chargement TrOCR (ZeroGPU optimisé)...")
        if not TORCH_AVAILABLE:
            print("❌ Torch non disponible, impossible de charger TrOCR")
            return False

        # Imported here so that a missing `transformers` package is reported
        # through the False return value instead of breaking module import.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        loaded_processor = TrOCRProcessor.from_pretrained(checkpoint)
        loaded_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

        # Inference only: disable training-time behavior
        loaded_model.eval()

        if torch.cuda.is_available():
            loaded_model = loaded_model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")
        else:
            device_info = "CPU (ZeroGPU pas encore alloué)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")

        # Publish both globals at once, now that nothing else can fail
        processor, model = loaded_processor, loaded_model
        return True
    except Exception as e:
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False
def get_ocr_model_info() -> dict:
    """Return a description of the OCR model and the device torch reports.

    The device fields reflect ``torch.cuda.is_available()`` at call time, not
    where the loaded model actually lives.

    Returns:
        A dict with the keys ``model_name``, ``device`` (``"ZeroGPU"`` or
        ``"CPU"``), ``gpu_name`` (``"N/A"`` without CUDA), ``framework``,
        ``optimized_for`` and ``version``.
    """
    # The short-circuit keeps `torch` untouched when its import failed
    if TORCH_AVAILABLE and torch.cuda.is_available():
        device = "ZeroGPU"
        # CUDA availability was just established; no need to test it again
        gpu_name = torch.cuda.get_device_name()
    else:
        device = "CPU"
        gpu_name = "N/A"

    return {
        "model_name": OCR_MODEL_NAME,
        "device": device,
        "gpu_name": gpu_name,
        "framework": "HuggingFace-Transformers-ZeroGPU",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten",
    }
@spaces.GPU  # ZeroGPU decorator
def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, Any, dict | None]:
    """Run TrOCR on an image and return the text plus the processed image data.

    Args:
        image_dict: Raw image input, passed as-is to ``optimize_image_for_ocr``
            (presumably the dict produced by the drawing widget — confirm
            against the caller).
        debug: When true, print progress and timing information.

    Returns:
        A ``(text, optimized_image, dataset_image_data)`` tuple, where ``text``
        is the decoded output after ``validate_ocr_result``. On any failure
        (missing image, uninitialized model, failed image optimization, or an
        exception during OCR) the result is ``("0", None, None)``; exceptions
        are printed, never raised.
    """
    failure = ("0", None, None)

    if image_dict is None:
        if debug:
            print(" ❌ Image manquante")
        return failure

    try:
        start_time = time.time()
        if debug:
            print(" 🔄 Début OCR TrOCR ZeroGPU...")

        # Cheap guard first: no point preprocessing the image without a model
        if processor is None or model is None:
            if debug:
                print(" ❌ TrOCR non initialisé")
            return failure

        # Optimize the image for OCR
        optimized_image = optimize_image_for_ocr(image_dict, max_size=384)
        if optimized_image is None:
            if debug:
                print(" ❌ Échec optimisation image")
            return failure

        if debug:
            print(" 🤖 Lancement TrOCR ZeroGPU...")

        # Inputs must live on the same device as the model. The model may have
        # been loaded on CPU even though CUDA is available now, so follow the
        # model rather than torch.cuda.is_available().
        target_device = model.device

        with torch.no_grad():
            # Preprocessing
            pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values
            pixel_values = pixel_values.to(target_device)

            # Greedy decoding of a short sequence.
            # NOTE(review): max_length presumably counts the decoder start
            # token as well, leaving fewer than 4 generated tokens — confirm
            # against the generate() documentation.
            generated_ids = model.generate(
                pixel_values,
                max_length=4,
                num_beams=1,
                do_sample=False,
                early_stopping=True,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Decoding
        result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        final_result = validate_ocr_result(result, max_length=4)

        # Prepare the image payload for the dataset
        dataset_image_data = prepare_image_for_dataset(optimized_image)

        if debug:
            total_time = time.time() - start_time
            # Report the device the model actually ran on
            device = "ZeroGPU" if target_device.type == "cuda" else "CPU"
            print(f" ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'")

        return final_result, optimized_image, dataset_image_data
    except Exception as e:
        print(f"❌ Erreur OCR TrOCR ZeroGPU: {e}")
        return failure
def recognize_number_fast(image_dict) -> tuple[str, Any]:
    """Recognize the number in *image_dict* and return it with the optimized image.

    Thin wrapper around ``recognize_number_fast_with_image`` (called with its
    default ``debug`` value) that drops the dataset payload.

    Returns:
        A ``(text, optimized_image)`` tuple.
    """
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Return only the recognized text for *image_dict*.

    Delegates to ``recognize_number_fast`` and discards the image it returns.
    """
    text, _optimized_image = recognize_number_fast(image_dict)
    return text