satvikjain committed on
Commit
1024113
·
0 Parent(s):

initial commit

Browse files
Files changed (7) hide show
  1. .gitattributes +36 -0
  2. Dockerfile +25 -0
  3. README.md +37 -0
  4. app.py +32 -0
  5. inference.py +201 -0
  6. model_final.pth +3 -0
  7. requirements.txt +16 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .pth filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+ ENV PIP_NO_CACHE_DIR=1
5
+
6
+ COPY requirements.txt /app/requirements.txt
7
+
8
+ RUN apt-get update && apt-get install -y \
9
+ git build-essential ffmpeg libsm6 libxext6 curl \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ RUN pip install --upgrade pip && pip install -r requirements.txt
13
+
14
+ # Detectron2 CPU build (from source) - may take time on first build
15
+ RUN pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2'
16
+
17
+ COPY app.py /app/app.py
18
+ COPY inference.py /app/inference.py
19
+ COPY model_final.pth /app/model_final.pth
20
+
21
+ ENV USE_GPU=false
22
+ EXPOSE 7860
23
+ CMD ["python", "app.py"]
24
+
25
+
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## PDF OCR (Detectron2 + TrOCR) - Hugging Face Spaces
2
+
3
+ This repo contains a deployable Gradio app that detects text lines with Detectron2 and reads them with TrOCR. Optional Gemini correction can refine the text.
4
+
5
+ ### Files
6
+ - `app.py`: Gradio UI
7
+ - `inference.py`: OCR pipeline (Detectron2 + TrOCR)
8
+ - `requirements.txt`: Python dependencies (Detectron2 installed in Dockerfile)
9
+ - `Dockerfile`: Docker image based on `python:3.10-slim`; builds Detectron2 from source for CPU and sets `USE_GPU=false` by default
10
+ - `model_final.pth`: Detectron2 weights
11
+
12
+ ### Deploy on Hugging Face Spaces (Docker Space)
13
+ 1. Create a new Space on Hugging Face → Type: Docker → Hardware: GPU (T4/A10G).
14
+ 2. Push these files to the Space repository (or connect this folder and `git push`).
15
+ 3. Set optional secret: `GEMINI_API_KEY` (for correction) in Space Settings → Secrets.
16
+ 4. Wait for the build to finish. The app will start on port 7860.
17
+
18
+ ### Use
19
+ 1. Upload a PDF.
20
+ 2. (Optional) Toggle Split-page mode (currently has no effect: the standard pipeline is always used) and/or Gemini correction.
21
+ 3. Click Process.
22
+ 4. Download the ZIP of per-page JSONs. The full combined text is shown in the textbox.
23
+
24
+ ### Local run (GPU recommended)
25
+ ```bash
26
+ docker build -t ocr-app .
27
+ docker run --gpus all -p 7860:7860 ocr-app
28
+ ```
29
+
30
+ Then open http://localhost:7860
31
+
32
+ ### Notes
33
+ - Detectron2 requires GPU for reasonable speed; CPU will be slow.
34
+ - `TEXTLINE_MODEL_PATH` can be overridden via env var if the weights are elsewhere.
35
+ - TrOCR models are downloaded on first run and cached inside the running container; they are downloaded again if the container is recreated.
36
+
37
+
app.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from inference import run_ocr
4
+
5
+
6
def predict(pdf_file, split_page, use_llm, gemini_key):
    """Gradio callback: run OCR on the uploaded PDF.

    Args:
        pdf_file: Value delivered by the ``gr.File`` input. Either ``None``
            (nothing uploaded), a path string, or an object exposing the
            path as ``.name``.
        split_page: State of the "Split-page mode" checkbox; forwarded to
            ``run_ocr`` as ``split_page_enabled``.
        use_llm: State of the "Gemini correction" checkbox.
        gemini_key: API key typed into the UI; when blank, the
            ``GEMINI_API_KEY`` environment variable is used instead.

    Returns:
        ``(text, zip_path)`` for the two output components. When no file was
        uploaded, a prompt message and ``None`` are returned.
    """
    if pdf_file is None:
        return "Please upload a PDF.", None

    # Accept both a plain path string and a file-like wrapper carrying `.name`,
    # so the callback does not depend on which form the File component delivers.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # A whitespace-only textbox value counts as "no key" and falls back to the env var.
    key = (gemini_key or "").strip() or os.getenv("GEMINI_API_KEY", None)

    text, zip_path = run_ocr(pdf_path, split_page_enabled=split_page, use_llm=use_llm, gemini_key=key)
    return text, zip_path
12
+
13
+
14
# UI layout: one row with two columns -- inputs on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## PDF OCR (Detectron2 + TrOCR)")
    with gr.Row():
        with gr.Column():
            # Input controls; their values are passed to predict() in this order.
            pdf = gr.File(label="Upload PDF", file_types=[".pdf"])
            split_page = gr.Checkbox(label="Split-page mode", value=False)
            use_llm = gr.Checkbox(label="Gemini correction", value=False)
            # Masked input; predict() falls back to GEMINI_API_KEY when this is empty.
            gemini_key = gr.Textbox(label="Gemini API Key (optional)", type="password")
            btn = gr.Button("Process")
        with gr.Column():
            # Outputs: the combined text and the ZIP path returned by predict().
            out_text = gr.Textbox(label="Extracted Text", lines=18)
            out_zip = gr.File(label="Per-page JSON (ZIP)")
    btn.click(predict, inputs=[pdf, split_page, use_llm, gemini_key], outputs=[out_text, out_zip])


if __name__ == "__main__":
    # Listen on all interfaces on port 7860 (the port the Dockerfile exposes).
    demo.launch(server_name="0.0.0.0", server_port=7860)
31
+
32
+
inference.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import json
4
+ import time
5
+ import shutil
6
+ import tempfile
7
+ from typing import Tuple
8
+
9
+ import cv2
10
+ import fitz # PyMuPDF
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ import torch
15
+ from detectron2.config import get_cfg
16
+ from detectron2.engine import DefaultPredictor
17
+ from detectron2.data import MetadataCatalog
18
+ from detectron2 import model_zoo
19
+
20
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
21
+
22
+
23
# -----------------------------
# Runtime settings. Every value below may be overridden through the
# environment variable of the same name (PDF_DPI for DPI).
# -----------------------------
TEXTLINE_MODEL_PATH = os.environ.get("TEXTLINE_MODEL_PATH", "./model_final.pth")
# Only the exact (case-insensitive) value "true" enables GPU use.
USE_GPU = os.environ.get("USE_GPU", "true").lower() == "true"
SCORE_THRESHOLD = float(os.environ.get("SCORE_THRESHOLD", "0.5"))
AREA_THRESHOLD_PERCENT = float(os.environ.get("AREA_THRESHOLD_PERCENT", "12.5"))
DPI = int(os.environ.get("PDF_DPI", "200"))

# Primary TrOCR checkpoint and the one used when the primary cannot be loaded.
TROCR_SPANISH_MODEL = os.environ.get("TROCR_SPANISH_MODEL", "qantev/trocr-large-spanish")
TROCR_FALLBACK_MODEL = os.environ.get("TROCR_FALLBACK_MODEL", "microsoft/trocr-base-printed")
34
+
35
+
36
class EnhancedTextlineExtractor:
    """Detect text-line boxes on page images with Detectron2 and read them with TrOCR.

    Constructing an instance builds the Detectron2 predictor from ``model_path``
    and loads a TrOCR processor/model pair, so instances are expensive to create
    and are best reused across pages.
    """

    def __init__(self, model_path: str):
        """Build the detector and the recognizer.

        Args:
            model_path: Path to the Detectron2 weights file.
        """
        self.cfg = self._setup_cfg(model_path)
        self.predictor = DefaultPredictor(self.cfg)

        # Init TrOCR on the same device the detector was configured for.
        self.device = torch.device(self._device_name())
        self.trocr_processor, self.trocr_model = self._load_trocr()
        self.trocr_model.to(self.device)

    @staticmethod
    def _device_name() -> str:
        """Return ``"cuda"`` when CUDA is available and ``USE_GPU`` is set, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() and USE_GPU else "cpu"

    def _setup_cfg(self, model_path: str):
        """Return the Detectron2 config for the two-class text-line Mask R-CNN."""
        cfg = get_cfg()
        cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
        cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2  # textline, baseline
        cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = SCORE_THRESHOLD
        cfg.MODEL.WEIGHTS = model_path
        # Detectron2 defaults MODEL.DEVICE to "cuda"; set it explicitly so the
        # predictor honours USE_GPU and still starts on CPU-only hosts.
        cfg.MODEL.DEVICE = self._device_name()
        cfg.DATASETS.TEST = ("page_test",)
        cfg.DATALOADER.NUM_WORKERS = 2
        MetadataCatalog.get("page_test").thing_classes = ["textline", "baseline"]
        return cfg

    def _load_trocr(self):
        """Load the TrOCR processor and model.

        ``TROCR_SPANISH_MODEL`` is tried first; if anything goes wrong while
        loading it, the failure is logged and ``TROCR_FALLBACK_MODEL`` is loaded
        instead (errors from the fallback propagate).

        Returns:
            ``(processor, model)``
        """
        import logging

        try:
            processor = TrOCRProcessor.from_pretrained(TROCR_SPANISH_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_SPANISH_MODEL)
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not load TrOCR model %r; falling back to %r",
                TROCR_SPANISH_MODEL,
                TROCR_FALLBACK_MODEL,
                exc_info=True,
            )
            processor = TrOCRProcessor.from_pretrained(TROCR_FALLBACK_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(TROCR_FALLBACK_MODEL)
        return processor, model

    def pdf_to_images(self, pdf_path: str, dpi: int = DPI):
        """Render every page of a PDF to an image array.

        Args:
            pdf_path: Path of the PDF to open.
            dpi: Render resolution; the page is scaled by ``dpi / 72``.

        Returns:
            A list with one array per page, as decoded by
            ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.
        """
        doc = fitz.open(pdf_path)
        images = []
        try:
            # The scale matrix is the same for every page, so build it once.
            zoom = dpi / 72
            mat = fitz.Matrix(zoom, zoom)
            for page_num in range(len(doc)):
                pix = doc.load_page(page_num).get_pixmap(matrix=mat)
                # Round-trip through PNG bytes to obtain an OpenCV image.
                nparr = np.frombuffer(pix.tobytes("png"), np.uint8)
                images.append(cv2.imdecode(nparr, cv2.IMREAD_COLOR))
        finally:
            doc.close()
        return images

    def filter_margin_boxes_by_area(self, boxes, scores, area_threshold_percent: float = AREA_THRESHOLD_PERCENT):
        """Split detections into "main" and "margin" groups by relative box area.

        A box is "main" when its area is at least ``area_threshold_percent``
        percent of the mean area of all boxes; otherwise it is "margin".

        Args:
            boxes: Sequence of ``[x1, y1, x2, y2]`` boxes.
            scores: Sequence of scores aligned with ``boxes``.
            area_threshold_percent: Cut-off as a percentage of the mean area.

        Returns:
            ``(main_boxes, main_scores, margin_boxes, margin_scores)`` as NumPy
            arrays; all four are empty when ``boxes`` is empty.
        """
        if len(boxes) == 0:
            return np.array([]), np.array([]), np.array([]), np.array([])

        boxes = np.asarray(boxes)
        scores = np.asarray(scores)
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        area_threshold = np.mean(areas) * (area_threshold_percent / 100.0)
        is_main = areas >= area_threshold
        return boxes[is_main], scores[is_main], boxes[~is_main], scores[~is_main]

    def process_page_standard(self, image):
        """Detect and transcribe the text lines of one page image.

        Args:
            image: Page image; it is passed to the Detectron2 predictor and its
                crops are converted with ``cv2.COLOR_BGR2RGB`` before OCR.

        Returns:
            On success ``{"success": True, "line_segments": [...], "full_text": str}``,
            where each segment has ``line_index``, ``bbox``, ``score``, ``text``
            and ``confidence``. When nothing is detected (or nothing survives
            filtering) ``{"success": False, "error": str}``.
        """
        import logging

        logger = logging.getLogger(__name__)

        outputs = self.predictor(image)
        instances = outputs["instances"]
        # NOTE(review): predicted classes are not inspected, so detections of both
        # configured classes are transcribed -- confirm whether "baseline"
        # instances should be excluded.
        boxes = instances.pred_boxes.tensor.cpu().numpy()
        scores = instances.scores.cpu().numpy()
        if len(boxes) == 0:
            return {"success": False, "error": "No textlines detected"}
        main_boxes, main_scores, _, _ = self.filter_margin_boxes_by_area(boxes, scores)
        if len(main_boxes) == 0:
            return {"success": False, "error": "No textlines after filtering"}

        # Transcribe in reading order (top-to-bottom, then left-to-right) rather
        # than in detector output order, so full_text follows the page layout.
        order = sorted(
            range(len(main_boxes)),
            key=lambda k: (float(main_boxes[k][1]), float(main_boxes[k][0])),
        )
        main_boxes = main_boxes[order]
        main_scores = main_scores[order]

        line_segments = []
        full_text_lines = []
        for i, (box, score) in enumerate(zip(main_boxes, main_scores)):
            x1, y1, x2, y2 = (int(v) for v in box)
            # Start from the "failed" shape; it is filled in when OCR succeeds.
            segment = {
                "line_index": i,
                "bbox": [x1, y1, x2, y2],
                "score": float(score),
                "text": "",
                "confidence": 0.0,
            }
            try:
                crop_bgr = image[y1:y2, x1:x2]
                crop_rgb = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(crop_rgb)
                pixel_values = self.trocr_processor(images=pil_image, return_tensors="pt").pixel_values
                pixel_values = pixel_values.to(self.device)
                with torch.no_grad():
                    generated_ids = self.trocr_model.generate(pixel_values, max_new_tokens=128)
                generated_text = self.trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                text = generated_text.strip()
            except Exception:
                # Best effort: keep the box in the output with empty text, but
                # record why the line could not be read.
                logger.exception("OCR failed for line %d (bbox=%s)", i, segment["bbox"])
            else:
                full_text_lines.append(text)
                segment["text"] = text
                # Fixed marker for "OCR succeeded"; not a model probability.
                segment["confidence"] = 1.0
            line_segments.append(segment)

        return {
            "success": True,
            "line_segments": line_segments,
            "full_text": "\n".join(full_text_lines)
        }
149
+
150
+
151
def _zip_directory(src_dir: str, zip_path: str) -> str:
    """Zip the contents of ``src_dir`` and return the archive's path.

    The extension of ``zip_path`` is dropped before archiving because
    ``shutil.make_archive`` appends ``.zip`` to the base name itself.
    """
    stem = os.path.splitext(zip_path)[0]
    return shutil.make_archive(stem, 'zip', src_dir)
155
+
156
+
157
def run_ocr(pdf_path: str, split_page_enabled: bool = False, use_llm: bool = False, gemini_key: str = None) -> Tuple[str, str]:
    """
    Run OCR on the provided PDF.

    Each page is rendered, processed with the standard pipeline, and its result
    written to ``page_<n>_result.json`` inside a fresh temporary directory; the
    JSON files are then zipped.

    Args:
        pdf_path: Path of the PDF to process.
        split_page_enabled: Accepted for interface compatibility; currently
            ignored (every page goes through the standard pipeline).
        use_llm: When true and ``gemini_key`` is set, the combined text is sent
            to Gemini for correction.
        gemini_key: Gemini API key; correction is skipped when it is empty.

    Returns:
        combined_text (str), zip_file_path (str)
    """
    import logging

    logger = logging.getLogger(__name__)

    # Reuse the loaded models across calls instead of reloading them per request.
    extractor = _get_extractor(TEXTLINE_MODEL_PATH)
    images = extractor.pdf_to_images(pdf_path, dpi=DPI)

    # The directory is intentionally left on disk: the returned archive lives in it.
    temp_dir = tempfile.mkdtemp(prefix="ocr_outputs_")
    inferences_dir = os.path.join(temp_dir, "inferences")
    os.makedirs(inferences_dir, exist_ok=True)

    all_results = []
    for i, image in enumerate(images):
        result = extractor.process_page_standard(image)
        all_results.append(result)
        page_file = os.path.join(inferences_dir, f"page_{i+1}_result.json")
        with open(page_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

    # Pages that failed contribute nothing to the combined text.
    combined_text = "\n\n".join(r.get("full_text", "") for r in all_results if r.get("success"))

    # Optional Gemini correction over combined text (simple, single pass)
    if use_llm and combined_text.strip():
        if not gemini_key:
            logger.warning("Gemini correction requested but no API key was provided; skipping")
        else:
            try:
                import google.generativeai as genai
                genai.configure(api_key=gemini_key)
                prompt = (
                    "Correct the following historical Spanish OCR text while preserving grammar and style. "
                    "Fix orthography, punctuation, and obvious OCR mistakes. Return only corrected text.\n\n" + combined_text
                )
                response = genai.GenerativeModel('gemini-2.5-pro').generate_content(prompt)
                if getattr(response, 'text', None):
                    combined_text = response.text.strip()
            except Exception:
                # Best effort: keep the uncorrected text, but leave a trace of the failure.
                logger.exception("Gemini correction failed; returning uncorrected OCR text")

    zip_path = os.path.join(temp_dir, "per_page_jsons.zip")
    archive_path = _zip_directory(inferences_dir, zip_path)
    return combined_text, archive_path


# Extractors already built, keyed by weights path.
_EXTRACTOR_CACHE = {}


def _get_extractor(model_path: str):
    """Return the shared extractor for ``model_path``, building it on first use."""
    extractor = _EXTRACTOR_CACHE.get(model_path)
    if extractor is None:
        extractor = EnhancedTextlineExtractor(model_path)
        _EXTRACTOR_CACHE[model_path] = extractor
    return extractor
200
+
201
+
model_final.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:466b008d261466deff5ab6f0517403441bcf7e379f39a715709b78503d252158
3
+ size 503106240
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ transformers==4.44.2
4
+ opencv-python-headless==4.11.0.86
5
+ pillow==11.2.1
6
+ layoutparser==0.3.4
7
+ pdfplumber==0.11.7
8
+ pymupdf==1.24.10
9
+ numpy==1.26.4
10
+ scipy==1.15.3
11
+ pandas==2.2.3
12
+ google-generativeai==0.7.2
13
+ python-dotenv==1.0.1
14
+ gradio==4.44.0
15
+ # detectron2 is not listed here: the Dockerfile installs it from source after these requirements
16
+