Spaces:
Sleeping
Sleeping
# ==========================================
# image_processing_gpu.py - GPU version with TrOCR
# ==========================================
"""
GPU-optimized image-processing module for math calculations.
Uses TrOCR for maximum accuracy on GPU.
"""
| import time | |
| import torch | |
| from utils import ( | |
| optimize_image_for_ocr, | |
| prepare_image_for_dataset, | |
| create_thumbnail_fast, | |
| create_white_canvas, | |
| log_memory_usage, | |
| cleanup_memory, | |
| decode_image_from_dataset, | |
| validate_ocr_result | |
| ) | |
# Global OCR state, populated by init_ocr_model().
processor = None  # TrOCRProcessor instance once loaded
model = None      # VisionEncoderDecoderModel instance once loaded
OCR_MODEL_NAME = "TrOCR-base-handwritten"
def init_ocr_model() -> bool:
    """Load TrOCR (processor + model) and move the model to GPU when available.

    Populates the module-level globals ``processor`` and ``model``.

    Returns:
        True when the model is ready for inference, False if loading failed.
    """
    global processor, model
    try:
        print("🔄 Chargement TrOCR (GPU optimisé)...")
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        checkpoint = 'microsoft/trocr-base-handwritten'
        processor = TrOCRProcessor.from_pretrained(checkpoint)
        model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

        # Inference-only: disable dropout / switch batch-norm to eval behaviour.
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
            device_info = f"GPU ({torch.cuda.get_device_name()})"
            print(f"✅ TrOCR prêt sur {device_info} !")
        else:
            device_info = "CPU (pas de GPU détecté)"
            print(f"⚠️ TrOCR sur CPU - {device_info}")
        return True
    except Exception as e:
        # Best-effort boundary: report and signal failure instead of crashing the app.
        print(f"❌ Erreur lors du chargement TrOCR: {e}")
        return False
def get_ocr_model_info() -> dict:
    """Describe the OCR model currently configured.

    Returns:
        Mapping with the model name, execution device ("GPU"/"CPU"),
        GPU name (or "N/A"), framework and checkpoint identifier.
    """
    cuda_ok = torch.cuda.is_available()
    info = {
        "model_name": OCR_MODEL_NAME,
        # Only report GPU when CUDA exists AND the model was actually loaded.
        "device": "GPU" if cuda_ok and model is not None else "CPU",
        "gpu_name": torch.cuda.get_device_name() if cuda_ok else "N/A",
        "framework": "HuggingFace-Transformers",
        "optimized_for": "accuracy",
        "version": "microsoft/trocr-base-handwritten",
    }
    return info
| def recognize_number_fast_with_image(image_dict, debug: bool = False) -> tuple[str, any, dict | None]: | |
| """ | |
| OCR avec TrOCR (GPU optimisé) | |
| Args: | |
| image_dict: Image d'entrée (format Gradio) | |
| debug: Afficher les logs de debug | |
| Returns: | |
| (résultat_ocr, image_optimisée, données_dataset) | |
| """ | |
| if image_dict is None or processor is None or model is None: | |
| if debug: | |
| print(" ❌ Image manquante ou TrOCR non initialisé") | |
| return "0", None, None | |
| try: | |
| start_time = time.time() | |
| if debug: | |
| print(" 🔄 Début OCR TrOCR...") | |
| # Optimiser image (fonction commune) | |
| optimized_image = optimize_image_for_ocr(image_dict, max_size=384) # TrOCR préfère 384x384 | |
| if optimized_image is None: | |
| if debug: | |
| print(" ❌ Échec optimisation image") | |
| return "0", None, None | |
| # TrOCR - traitement spécialisé GPU | |
| if debug: | |
| print(" 🤖 Lancement TrOCR...") | |
| with torch.no_grad(): | |
| # Preprocessing | |
| pixel_values = processor(images=optimized_image, return_tensors="pt").pixel_values | |
| # GPU transfer si disponible | |
| if torch.cuda.is_available(): | |
| pixel_values = pixel_values.cuda() | |
| # Génération optimisée | |
| generated_ids = model.generate( | |
| pixel_values, | |
| max_length=4, # Optimisé pour les calculs | |
| num_beams=1, # Rapide | |
| do_sample=False, # Déterministe | |
| early_stopping=True, # Arrêt rapide | |
| pad_token_id=processor.tokenizer.pad_token_id | |
| ) | |
| # Décodage | |
| result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
| final_result = validate_ocr_result(result, max_length=4) | |
| # Préparer pour dataset (fonction commune) | |
| dataset_image_data = prepare_image_for_dataset(optimized_image) | |
| if debug: | |
| total_time = time.time() - start_time | |
| device = "GPU" if torch.cuda.is_available() else "CPU" | |
| print(f" ✅ TrOCR ({device}) terminé en {total_time:.1f}s → '{final_result}'") | |
| return final_result, optimized_image, dataset_image_data | |
| except Exception as e: | |
| print(f"❌ Erreur OCR TrOCR: {e}") | |
| return "0", None, None | |
def recognize_number_fast(image_dict) -> tuple[str, object]:
    """Fast OCR wrapper: return only the recognized text and optimized image.

    Args:
        image_dict: Input image in Gradio sketch format (None when empty).

    Returns:
        Tuple of (ocr_result, optimized_image); dataset payload is discarded.
    """
    # Fixed defect: the return annotation previously used the builtin
    # function `any` as a type; `object` is the correct stand-in here.
    result, optimized_image, _ = recognize_number_fast_with_image(image_dict)
    return result, optimized_image
def recognize_number(image_dict) -> str:
    """Standard OCR interface: return only the recognized text."""
    text, _image = recognize_number_fast(image_dict)
    return text
# Fine-tuning-specific functions (for later)
def prepare_for_finetuning(dataset_path: str) -> dict:
    """Placeholder: prepare the dataset for TrOCR fine-tuning.

    ``dataset_path`` is currently unused.
    TODO: implement once HF Pro is available.
    """
    return {"status": "ready_for_implementation"}
def quantize_model() -> bool:
    """Placeholder: quantize the TrOCR model for better CPU performance.

    Returns:
        False — quantization is not implemented yet.
    """
    # TODO: implement quantization
    return False