""" HandwrittenOCR - تدريب LoRA على TrOCR v4.0 ============================================== المحسنات: - global trocr_model (تصحيح #10) - commit_message مع التاريخ - تحديث تلقائي للنموذج في OCREngine """ import os import io import logging from PIL import Image from datetime import datetime logger = logging.getLogger("HandwrittenOCR") def finetune_trocr_lora( ocr_engine, db, save_path: str, min_samples: int = 50, epochs: int = 5, batch_size: int = 4, lr: float = 1e-5, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.1, lora_target_modules: list | None = None, ) -> bool: """ تدريب TrOCR باستخدام LoRA على البيانات الموثقة. بعد التدريب يُحدَّث ocr_engine.trocr_model تلقائياً. """ try: from peft import get_peft_model, LoraConfig, TaskType from torch.optim import AdamW from torch.utils.data import Dataset, DataLoader except ImportError: logger.error("peft غير مثبت") return False if lora_target_modules is None: lora_target_modules = ["query", "value"] # تصريح global لتحديث النموذج (تصحيح #10) global trocr_model trocr_model = ocr_engine.trocr_model trocr_processor = ocr_engine.trocr_processor device = ocr_engine.device # فحص العينات (يشمل verified و sentence_corrected) verified = db.get_verified() verified = [ w for w in verified if w.get("status") in ("verified", "sentence_corrected") ] if len(verified) < min_samples: logger.warning( f"عينات غير كافية: {len(verified)} < {min_samples}" ) return False print(f"بدء التدريب على {len(verified)} عينة...") # إعداد LoRA lora_config = LoraConfig( task_type=TaskType.SEQ_2_SEQ_LM, r=lora_r, lora_alpha=lora_alpha, target_modules=lora_target_modules, lora_dropout=lora_dropout, ) model = get_peft_model(trocr_model, lora_config) model.train() # Dataset class HandwritingDataset(Dataset): def __init__(self, records): self.records = records def __len__(self): return len(self.records) def __getitem__(self, idx): row = self.records[idx] img = Image.open(io.BytesIO(row["image_data"])).convert("RGB") pixel_values = trocr_processor( images=img, return_tensors="pt" ).pixel_values.squeeze() text = row["predicted_text"] or "" labels = trocr_processor.tokenizer( text, return_tensors="pt", padding="max_length", max_length=64, ).input_ids.squeeze() labels[labels == trocr_processor.tokenizer.pad_token_id] = -100 return {"pixel_values": pixel_values, "labels": labels} dataset = HandwritingDataset(verified) loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) optimizer = AdamW(model.parameters(), lr=lr) # التدريب for epoch in range(epochs): total_loss = 0 batch_count = 0 for batch in loader: out = model( pixel_values=batch["pixel_values"].to(device), labels=batch["labels"].to(device), ) out.loss.backward() optimizer.step() optimizer.zero_grad() total_loss += out.loss.item() batch_count += 1 avg_loss = total_loss / max(batch_count, 1) print(f"Epoch {epoch + 1}/{epochs} | Loss: {avg_loss:.4f}") # حفظ النموذج os.makedirs(save_path, exist_ok=True) model.save_pretrained(save_path) trocr_processor.save_pretrained(save_path) # تحديث النموذج في OCREngine تلقائياً ocr_engine.trocr_model = model ocr_engine.lora_loaded = True print(f"تم حفظ النموذج في: {save_path}") logger.info(f"تم تدريب LoRA وحفظه في: {save_path}") return True def check_auto_train(db, min_samples: int = 100) -> bool: """ فحص ما إذا كان عدد العينات المؤكدة كافياً للتدريب التلقائي. يُستدعى بعد كل عملية مراجعة لتشغيل التدريب تلقائياً. """ verified_count = db.get_verified_count() logger.debug(f"check_auto_train: {verified_count} عينة مؤكدة (مطلوب ≥{min_samples})") if verified_count >= min_samples: logger.info(f"العدد كافي ({verified_count} ≥ {min_samples}) — يمكن بدء التدريب التلقائي") return True return False