# FiberGate / scripts/batch_process_dataref.py
# Last commit 5647d1a by AzizMiladi:
#   "refactor: replace importlib hacks with normal package imports"
# (copied from a file-viewer page: Raw · History · Blame · 4.66 kB)
"""
Batch process all documents in DataRef folder using subprocess.
Calls `python -m guichetoi.inference` on each image to avoid import issues.
"""
import json
import logging
import subprocess
from pathlib import Path
from collections import defaultdict
import sys
# Timestamped, level-padded log lines for the whole batch run.
_LOG_FORMAT = "%(asctime)s %(levelname)-7s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used for every message this script emits.
log = logging.getLogger("batch_process")
# File suffixes (compared lower-cased) that are treated as documents to process.
_DOC_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".pdf", ".bmp", ".tif", ".tiff"})
# Per-document limit, in seconds, for the inference subprocess.
_INFERENCE_TIMEOUT_S = 300
# Keys of a result's "fields" mapping that are echoed to the log when present.
_LOG_FIELDS = ("Reference_Urbanisme", "DLPI", "cabinet_conseil", "nb_log_totale", "Nb_log_pro", "Nb_log_res")


def _run_inference(file_path):
    """Run the inference CLI on one document; return True iff it exits with status 0.

    On a non-zero exit the tail of the CLI's stderr is logged.
    Raises subprocess.TimeoutExpired after _INFERENCE_TIMEOUT_S seconds.
    """
    # Use the interpreter running this script so the CLI resolves the same
    # installed package (`pip install -e .` required) instead of whatever
    # "python" happens to be first on PATH.
    interpreter = sys.executable or "python"
    cmd = [interpreter, "-m", "guichetoi.inference", "--image", str(file_path), "--device", "cpu"]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=_INFERENCE_TIMEOUT_S)
    if result.returncode == 0:
        return True
    # The end of stderr (last traceback line) is usually the informative part.
    log.error(f" ERROR: CLI returned code {result.returncode}: {result.stderr.strip()[-200:]}")
    return False


def _move_result(result_file, dest_file):
    """Move result_file to dest_file, falling back to copy + delete when rename fails."""
    import shutil

    try:
        result_file.replace(dest_file)
        return
    except OSError:
        # e.g. source and destination on different filesystems
        shutil.copy(result_file, dest_file)
    try:
        result_file.unlink()
    except OSError as e:
        log.warning(f" WARNING: could not remove {result_file}: {e}")


def _load_result(dest_file):
    """Parse the per-document JSON; return it as a dict, or None if unusable."""
    try:
        with open(dest_file, "r", encoding="utf-8") as f:
            output_data = json.load(f)
    except json.JSONDecodeError as e:
        log.error(f" ERROR: Failed to parse JSON output: {e}")
        return None
    if not isinstance(output_data, dict):
        log.error(f" ERROR: Unexpected JSON output type: {type(output_data).__name__}")
        return None
    return output_data


def _log_key_fields(fields):
    """Log the _LOG_FIELDS entries present in a result's "fields" mapping."""
    if not isinstance(fields, dict):
        return
    parts = []
    for name in _LOG_FIELDS:
        if name not in fields:
            continue
        entry = fields[name]
        # Entries are expected to be {"value": ...} dicts; tolerate bare values.
        value = entry.get("value", "?") if isinstance(entry, dict) else entry
        parts.append(f"{name}={value}")
    if parts:
        log.info(f" → Extracted: {', '.join(parts)}")


def _process_document(file_path, processed_dir):
    """Run inference on one document and return its parsed result dict, or None on failure.

    The CLI output is read from outputs/<stem>_result.json (relative to the
    current directory) and moved into processed_dir.
    NOTE: documents sharing a file stem (e.g. a/doc.pdf and b/doc.png) map to
    the same per-document file name, so the later one overwrites the earlier.
    """
    result_file = Path("outputs") / f"{file_path.stem}_result.json"
    # Drop any leftover from an earlier run so it cannot be mistaken for this one's output.
    try:
        result_file.unlink()
    except FileNotFoundError:
        pass
    if not _run_inference(file_path):
        return None
    if not result_file.exists():
        log.error(f" ERROR: Output file not created: {result_file}")
        return None
    dest_file = processed_dir / result_file.name
    _move_result(result_file, dest_file)
    return _load_result(dest_file)


def main():
    """Run inference on every document under ./DataRef and collect the results.

    All paths are relative to the current working directory:

    * documents: files under ``DataRef/`` (searched recursively) whose
      lower-cased suffix is in ``_DOC_EXTENSIONS``;
    * per-document JSON: read from ``outputs/<stem>_result.json`` after the
      CLI succeeds, then moved into ``processed_dataref/``;
    * summary: ``processed_dataref/batch_dataref_results.json`` holding
      ``total_processed``, ``statistics`` and the list of ``results``.

    A failure on one document (any ``Exception``, including a timeout) is
    logged and counted under ``statistics["errors"]``; the batch continues.
    Returns None; if ``DataRef`` is missing nothing is written.
    """
    dataref_dir = Path("DataRef")
    if not dataref_dir.is_dir():
        log.error(f"DataRef directory not found: {dataref_dir}")
        return

    # Only real files: a directory named e.g. "scans.pdf" must not reach the CLI.
    files = sorted(
        f for f in dataref_dir.rglob("*")
        if f.is_file() and f.suffix.lower() in _DOC_EXTENSIONS
    )
    log.info(f"Found {len(files)} document(s) in DataRef")

    results = []
    stats = defaultdict(int)
    # Destination for per-document JSON results from this batch.
    processed_dir = Path("processed_dataref")
    processed_dir.mkdir(parents=True, exist_ok=True)

    for i, file_path in enumerate(files, 1):
        rel_path = file_path.relative_to(dataref_dir)
        log.info(f"[{i}/{len(files)}] Processing: {rel_path}")
        try:
            output_data = _process_document(file_path, processed_dir)
        except subprocess.TimeoutExpired:
            log.error(f" ERROR: Processing timed out (>{_INFERENCE_TIMEOUT_S}s)")
            output_data = None
        except Exception as e:
            log.error(f" ERROR: {e}")
            output_data = None
        if output_data is None:
            stats["errors"] += 1
            continue

        # Counted exactly once: nothing below can turn a success into an error.
        results.append(output_data)
        stats["total"] += 1
        if "doc_class" in output_data:
            stats[f"class_{output_data['doc_class']}"] += 1
        if output_data.get("fields"):
            stats["with_fields"] += 1
        _log_key_fields(output_data.get("fields"))

    # Save batch results into processed_dataref (dumped before the summary
    # below reads stats, so absent counters are not written as zeros).
    output_file = processed_dir / "batch_dataref_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({
            "total_processed": len(results),
            "statistics": dict(stats),
            "results": results,
        }, f, ensure_ascii=False, indent=2)

    log.info(f"\n{'=' * 60}")
    log.info("Batch processing complete!")
    log.info(f" Total: {stats['total']}")
    log.info(f" With fields extracted: {stats['with_fields']}")
    log.info(f" Errors: {stats['errors']}")
    log.info(f" Results saved to: {output_file}")
    log.info(f"{'=' * 60}")
if __name__ == "__main__":
main()