# newspaper-api / extractor.py
# Author: gohilnath2
# Last commit 15f3011: "Priority queue, single SambaNova provider, sequential page processing"
# =============================================================================
# 📰 Newspaper Article Extractor — Core Pipeline
# No UI dependencies. Can be used standalone:
# from extractor import ExtractionPipeline
# pipeline = ExtractionPipeline(api_key="...")
# result, viz_image, regions, is_digital, total_pages = pipeline.extract(pdf_path, page_num=0)
# =============================================================================
import json
import time
import re
import base64
import fitz
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from huggingface_hub import snapshot_download
from openai import OpenAI
from difflib import SequenceMatcher
import io
import os
import logging
from config import (
DPI, SCALE_FACTOR, LLM_BASE_URL, VISION_MODEL,
YOLO_REPO, YOLO_CONF_THRESHOLD,
SKIP_CLASSES, HEADLINE_CLASSES, BODY_CLASSES,
BBOX_PADDING, HEADLINE_DISTANCE_FACTOR, HEADLINE_MIN_DISTANCE,
GROUPING_PROMPT,
)
logger = logging.getLogger(name="newspaper_extractor")

# Shared EasyOCR reader. It stays None until the first region that needs OCR,
# so purely digital PDFs never import or initialise EasyOCR.
_ocr_reader = None


def _get_ocr_reader():
    """Return the module-wide EasyOCR reader, creating it on first use.

    The reader is English-only and runs with ``gpu=False``. The ``easyocr``
    import is deferred to this call, so importing this module does not
    require it.
    """
    global _ocr_reader
    if _ocr_reader is not None:
        return _ocr_reader
    import easyocr

    logger.info("Loading EasyOCR fallback...")
    _ocr_reader = easyocr.Reader(['en'], gpu=False)
    return _ocr_reader
# =============================================================================
# Pipeline class
# =============================================================================
class ExtractionPipeline:
    """Detect, read and group newspaper articles on one PDF page at a time.

    Construct once: this downloads/loads the YOLO layout model and creates the
    OpenAI-compatible vision-LLM client. Then call :meth:`extract` per page.

    Stages of :meth:`extract`:
      1. Render the page to an RGB image at ``SCALE_FACTOR`` zoom.
      2. Detect layout regions with YOLO.
      3. Read each region's text (PyMuPDF text layer for digital pages,
         EasyOCR otherwise).
      4. Ask the vision LLM to group the numbered regions into articles.
      5. Assemble article dicts (headline, body, byline, caption, ...).

    Region dicts carry ``"bbox"`` (``[x1, y1, x2, y2]`` in rendered-image
    pixels), ``"class"`` and ``"confidence"``. After stage 3 they also carry
    ``"text"`` and ``"text_source"``.
    """

    # Patterns compiled once and shared by the helpers below.
    # Matches "429", "rate limit"/"rate_limit"/"ratelimit", the standalone
    # word "rate", or "too many requests" -- but not words such as "generate".
    _RATE_LIMIT_RE = re.compile(
        r"\b429\b|\brate(?:[\s_-]?limit|\b)|too many requests", re.IGNORECASE,
    )
    # A retry hint such as "try again in 7.5s" / "after 30 seconds".
    _RETRY_AFTER_RE = re.compile(
        r"(\d+(?:\.\d+)?)\s*(?:s|secs?|seconds?)\b", re.IGNORECASE,
    )
    # A reply wrapped entirely in a Markdown code fence (optional language tag).
    _CODE_FENCE_RE = re.compile(r"^```[A-Za-z0-9_-]*\s*(.*?)\s*```\s*$", re.DOTALL)
    _CONTINUED_ON_RE = re.compile(r"CONTINUED\s+ON", re.IGNORECASE)
    _CONTINUED_BODY_RE = re.compile(r"CONTINUED\s+ON\s+.*?PAGE\s+(\d+)", re.IGNORECASE)
    _PAGE_NUM_RE = re.compile(r"PAGE\s+(\d+)", re.IGNORECASE)
    # Byline heuristic: first line is 2-4 capitalised words, e.g. "John Smith".
    _BYLINE_RE = re.compile(r"^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,3}\s*$")

    def __init__(self, api_key, cache_dir="/tmp/hf_cache"):
        """Load the YOLO layout model and create the vision-LLM client.

        Args:
            api_key: API key for the OpenAI-compatible endpoint at ``LLM_BASE_URL``.
            cache_dir: Hugging Face cache directory used for the YOLO weights.
        """
        logger.info("Loading YOLO model...")
        repo_path = snapshot_download(YOLO_REPO, cache_dir=cache_dir)
        self.yolo_model = YOLO(os.path.join(repo_path, "weights", "best.pt"))
        logger.info(f"YOLO classes: {self.yolo_model.names}")
        self.llm_client = OpenAI(base_url=LLM_BASE_URL, api_key=api_key)
        logger.info("✅ Pipeline initialized")

    # -----------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------
    def extract(self, pdf_path, page_num=0):
        """Extract articles from a single PDF page.

        Args:
            pdf_path: Path to the PDF file.
            page_num: Zero-based page index.

        Returns:
            A 5-tuple ``(result_dict, viz_image, regions, is_digital, total_pages)``:

            - ``result_dict`` is ``{"articles": [...]}``.
            - ``viz_image`` is a PIL image with the detected regions drawn on it.
            - ``regions`` is the list of region dicts.

            When ``page_num`` is outside ``[0, total_pages)``, the first three
            items are ``None`` and ``is_digital`` is ``False``.
        """
        # Render first, because rendering validates page_num. _is_digital_pdf
        # indexes the page directly and would raise for a page past the end.
        image, total_pages = self._pdf_page_to_image(pdf_path, page_num)
        if image is None:
            return None, None, None, False, total_pages
        is_digital = self._is_digital_pdf(pdf_path, page_num)

        regions = self._detect_layout(image)
        viz_image = self._visualize_layout(image, regions)
        img_b64 = self._create_numbered_image(image, regions)
        self._extract_region_texts(regions, pdf_path, page_num, image, is_digital)
        grouping = self._group_regions(img_b64, regions)
        result = self._assemble_articles(grouping, regions, pdf_path, page_num, image, is_digital)
        return result, viz_image, regions, is_digital, total_pages

    def get_page_count(self, pdf_path):
        """Return the total number of pages in the PDF at ``pdf_path``."""
        with fitz.open(pdf_path) as doc:
            return doc.page_count

    # -----------------------------------------------------------------
    # PDF helpers
    # -----------------------------------------------------------------
    def _is_digital_pdf(self, pdf_path, page_num=0, min_chars=500):
        """Return True when the page's text layer holds more than ``min_chars`` characters.

        Pages with little or no extractable text are treated as scanned, and
        their regions are read with OCR instead.
        """
        with fitz.open(pdf_path) as doc:
            text = doc[page_num].get_text("text").strip()
        is_digital = len(text) > min_chars
        logger.info(f"PDF type: {'Digital' if is_digital else 'Scanned'} ({len(text)} chars)")
        return is_digital

    def _pdf_page_to_image(self, pdf_path, page_num):
        """Render one page to an RGB PIL image at ``SCALE_FACTOR`` zoom.

        Returns:
            ``(image, total_pages)``. ``image`` is None when ``page_num`` is
            outside ``[0, total_pages)``.
        """
        with fitz.open(pdf_path) as doc:
            total = doc.page_count
            if not 0 <= page_num < total:
                return None, total
            pix = doc[page_num].get_pixmap(matrix=fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR))
            return Image.frombytes("RGB", [pix.width, pix.height], pix.samples), total

    @staticmethod
    def _pixel_bbox_to_pdf_rect(bbox):
        """Convert a rendered-image pixel bbox back to PDF coordinates (divide by ``SCALE_FACTOR``)."""
        x1, y1, x2, y2 = bbox
        return fitz.Rect(
            x1 / SCALE_FACTOR, y1 / SCALE_FACTOR,
            x2 / SCALE_FACTOR, y2 / SCALE_FACTOR,
        )

    @classmethod
    def _clip_page_text(cls, page, bbox):
        """Return the stripped text-layer content of an open ``page`` inside pixel ``bbox``."""
        return page.get_text("text", clip=cls._pixel_bbox_to_pdf_rect(bbox)).strip()

    # -----------------------------------------------------------------
    # Layout detection
    # -----------------------------------------------------------------
    def _detect_layout(self, image):
        """Run YOLO on ``image`` and return region dicts.

        Regions are sorted by top edge, then left edge. Coordinates are
        truncated to ints and confidence is rounded to 3 decimals.
        """
        results = self.yolo_model.predict(
            source=image, conf=YOLO_CONF_THRESHOLD, imgsz=1024, verbose=False,
        )
        regions = []
        for result in results:
            boxes = result.boxes
            for xyxy, conf, cls_id in zip(
                boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist(),
            ):
                x1, y1, x2, y2 = xyxy
                regions.append({
                    "bbox": [int(x1), int(y1), int(x2), int(y2)],
                    "class": result.names[int(cls_id)],
                    "confidence": round(conf, 3),
                })
        regions.sort(key=lambda r: (r["bbox"][1], r["bbox"][0]))
        logger.info(f"Detected {len(regions)} regions")
        return regions

    # -----------------------------------------------------------------
    # Visualization
    # -----------------------------------------------------------------
    @staticmethod
    def _visualize_layout(image, regions):
        """Return a copy of ``image`` with each region outlined and labelled "[i] class".

        Colours are looked up by lower-cased class name; unknown classes are grey.
        """
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        colors = {
            "title": "#E24B4A", "text": "#378ADD", "picture": "#639922",
            "figure": "#639922", "table": "#BA7517", "caption": "#1D9E75",
            "section-header": "#E24B4A", "header": "#888780", "footer": "#888780",
        }
        for i, r in enumerate(regions):
            x1, y1, x2, y2 = r["bbox"]
            color = colors.get(r["class"].lower(), "#888780")
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            label = f"[{i}] {r['class']}"
            # Approximate label width of the default font (~7 px per char).
            draw.rectangle([x1, y1, x1 + len(label) * 7, y1 + 16], fill=color)
            draw.text((x1 + 2, y1 + 1), label, fill="white")
        return img_copy

    @staticmethod
    def _create_numbered_image(image, regions):
        """Return a base64 JPEG of ``image`` with each region boxed and numbered.

        The number drawn on a region is its index in ``regions``, which is the
        index the grouping prompt uses. The image is downscaled to fit within
        1000x1000 and saved at JPEG quality 55.
        """
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        try:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28,
            )
        except OSError:
            font = ImageFont.load_default()
        for i, r in enumerate(regions):
            x1, y1, x2, y2 = r["bbox"]
            label = str(i)
            # Roughly 20 px per digit at 28 pt. Widen the tag so multi-digit
            # indices stay on the red background; single digits keep 32 px.
            tag_w = max(32, 12 + 20 * len(label))
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            draw.rectangle([x1, y1, x1 + tag_w, y1 + 32], fill="red")
            draw.text((x1 + 6, y1 + 2), label, fill="white", font=font)
        img_copy.thumbnail((1000, 1000))
        buf = io.BytesIO()
        img_copy.save(buf, format="JPEG", quality=55)
        return base64.b64encode(buf.getvalue()).decode()

    # -----------------------------------------------------------------
    # Text extraction
    # -----------------------------------------------------------------
    def _extract_region_text_pymupdf(self, pdf_path, page_num, bbox):
        """Return the PDF text-layer content inside pixel ``bbox`` on one page."""
        with fitz.open(pdf_path) as doc:
            return self._clip_page_text(doc[page_num], bbox)

    @staticmethod
    def _extract_region_text_ocr(image, bbox, region_class, pad=5):
        """OCR one region of ``image`` with EasyOCR.

        The crop is expanded by ``pad`` pixels on each side and clamped to the
        image. Headline-class regions use a stricter ``text_threshold`` (0.5
        vs 0.4). Returns "" for a zero-area crop.
        """
        x1, y1, x2, y2 = bbox
        x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
        x2, y2 = min(image.width, x2 + pad), min(image.height, y2 + pad)
        if x2 <= x1 or y2 <= y1:
            return ""
        reader = _get_ocr_reader()
        crop_np = np.array(image.crop((x1, y1, x2, y2)))
        threshold = 0.5 if region_class.lower() in HEADLINE_CLASSES else 0.4
        results = reader.readtext(crop_np, paragraph=True, text_threshold=threshold)
        return " ".join(r[1] for r in results).strip()

    def _extract_region_texts(self, regions, pdf_path, page_num, image, is_digital):
        """Fill ``"text"`` and ``"text_source"`` on every region, in place.

        - SKIP_CLASSES regions get "" with source "skipped".
        - On digital pages the PDF text layer is tried first and accepted when
          it yields more than 3 characters (source "pymupdf").
        - Everything else is OCR'd (source "ocr").

        The PDF is opened once for the whole page rather than once per region.
        """
        doc = fitz.open(pdf_path) if is_digital else None
        try:
            page = doc[page_num] if doc is not None else None
            for r in regions:
                if r["class"].lower() in SKIP_CLASSES:
                    r["text"] = ""
                    r["text_source"] = "skipped"
                    continue
                if page is not None:
                    text = self._clip_page_text(page, r["bbox"])
                    if len(text) > 3:
                        r["text"] = text
                        r["text_source"] = "pymupdf"
                        continue
                r["text"] = self._extract_region_text_ocr(image, r["bbox"], r["class"])
                r["text_source"] = "ocr"
        finally:
            if doc is not None:
                doc.close()

    def _extract_article_body_bbox(self, pdf_path, page_num, body_idxs, regions,
                                   headline_bbox=None):
        """Read an article body from the union box of its text regions.

        Text comes from PyMuPDF's ``get_text("text")`` clipped to that box.

        If ``headline_bbox`` is given, only regions whose horizontal centre
        lies within ``max(headline_width * HEADLINE_DISTANCE_FACTOR,
        HEADLINE_MIN_DISTANCE)`` of the headline's centre are used. If no
        region passes that test, all of them are used.

        The union box is inset by ``abs(BBOX_PADDING)`` on each axis, but only
        where the box is wide/tall enough that the inset cannot invert it.
        """
        if not body_idxs:
            return ""
        valid = list(body_idxs)

        if headline_bbox:
            h_x1, _, h_x2, _ = headline_bbox
            h_center = (h_x1 + h_x2) / 2
            max_dist = max((h_x2 - h_x1) * HEADLINE_DISTANCE_FACTOR, HEADLINE_MIN_DISTANCE)
            near = [
                i for i in valid
                if abs((regions[i]["bbox"][0] + regions[i]["bbox"][2]) / 2 - h_center)
                <= max_dist
            ]
            if near:
                valid = near

        boxes = [regions[i]["bbox"] for i in valid]
        x1 = min(b[0] for b in boxes)
        y1 = min(b[1] for b in boxes)
        x2 = max(b[2] for b in boxes)
        y2 = max(b[3] for b in boxes)
        pad = abs(BBOX_PADDING)
        if x2 - x1 > 2 * pad:
            x1, x2 = x1 + pad, x2 - pad
        if y2 - y1 > 2 * pad:
            y1, y2 = y1 + pad, y2 - pad

        with fitz.open(pdf_path) as doc:
            return self._clip_page_text(doc[page_num], [x1, y1, x2, y2])

    # -----------------------------------------------------------------
    # Vision LLM
    # -----------------------------------------------------------------
    @classmethod
    def _is_rate_limit_error(cls, exc):
        """Return True if ``exc`` looks like an HTTP 429 / rate-limit error."""
        # OpenAI SDK status errors expose ``status_code``. Fall back to the
        # message text for errors that only report it as text.
        if getattr(exc, "status_code", None) == 429:
            return True
        return bool(cls._RATE_LIMIT_RE.search(str(exc)))

    @classmethod
    def _retry_wait_seconds(cls, exc, default=60):
        """Return seconds to wait: the "N s" hint in ``exc`` plus 2, otherwise ``default``."""
        m = cls._RETRY_AFTER_RE.search(str(exc))
        return float(m.group(1)) + 2 if m else default

    def _call_vision_llm(self, img_b64, prompt, max_retries=3):
        """Send one JPEG image plus a text prompt to ``VISION_MODEL``; return the reply text.

        Rate-limit errors are retried, up to ``max_retries`` attempts in total.
        Between attempts it sleeps for the delay suggested in the error (+2 s),
        or 60 s if none is given. Any other error is re-raised immediately.

        Raises:
            RuntimeError: every attempt was rate limited, or ``max_retries`` < 1.
        """
        messages = [{
            "role": "user",
            "content": [
                {"type": "image_url",
                 "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                {"type": "text", "text": prompt},
            ],
        }]
        for attempt in range(max_retries):
            try:
                resp = self.llm_client.chat.completions.create(
                    model=VISION_MODEL,
                    messages=messages,
                    temperature=0.1,
                    max_tokens=2048,
                )
                return resp.choices[0].message.content
            except Exception as e:
                if not self._is_rate_limit_error(e):
                    raise
                if attempt + 1 >= max_retries:
                    # No attempts left, so don't sleep before giving up.
                    raise RuntimeError("Vision LLM failed after retries") from e
                wait = self._retry_wait_seconds(e)
                logger.warning(f"Rate limited, waiting {wait:.0f}s (attempt {attempt + 1})")
                time.sleep(wait)
        raise RuntimeError("Vision LLM failed after retries")

    @classmethod
    def _parse_grouping_json(cls, raw):
        """Parse the grouping LLM reply into a dict.

        Accepts a bare JSON object or one wrapped in a Markdown code fence,
        including single-line fences. If that fails to parse, retries on the
        span from the first "{" to the last "}" to drop surrounding prose.

        Raises:
            json.JSONDecodeError: no parseable JSON object was found.
            ValueError: the JSON is valid but is not an object.
        """
        text = raw.strip()
        fence = cls._CODE_FENCE_RE.match(text)
        if fence:
            text = fence.group(1)
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            data = json.loads(text[start:end + 1])
        if not isinstance(data, dict):
            raise ValueError(
                f"Grouping LLM returned JSON {type(data).__name__}, expected an object"
            )
        return data

    def _group_regions(self, img_b64, regions):
        """Ask the vision LLM to group regions into articles and return its parsed JSON.

        The prompt lists each region as "[i] class — preview". The preview is
        the first 80 characters of the region text, or a no-text marker.
        Unassigned text regions are only logged: the later bounding-box body
        extraction usually captures them anyway.
        """
        lines = []
        for i, r in enumerate(regions):
            text = r.get("text", "")
            if text:
                preview = text[:80].replace("\n", " ")
                lines.append(f"[{i}] {r['class']} — \"{preview}\"")
            else:
                lines.append(f"[{i}] {r['class']} — (no text / image)")
        prompt = GROUPING_PROMPT.format(region_summary="\n".join(lines))
        # The reply content may be None; "" then fails JSON parsing with a clear error.
        grouping = self._parse_grouping_json(self._call_vision_llm(img_b64, prompt) or "")

        assigned = set()
        for art in grouping.get("articles") or []:
            if not isinstance(art, dict):
                continue
            if art.get("headline_region") is not None:
                assigned.add(art["headline_region"])
            assigned.update(art.get("body_regions") or [])
        assigned.update(grouping.get("discarded_regions") or [])
        orphaned_text = [
            i for i in set(range(len(regions))) - assigned
            if regions[i].get("text")
        ]
        if orphaned_text:
            logger.info(
                f"{len(orphaned_text)} unassigned text regions "
                "(bounding box will capture them)"
            )
        logger.info(f"Grouped into {len(grouping.get('articles') or [])} articles")
        return grouping

    # -----------------------------------------------------------------
    # Text formatting
    # -----------------------------------------------------------------
    @staticmethod
    def _format_body_text(raw_text):
        """Turn PyMuPDF's line-per-visual-line output into paragraphed text.

        1. A line ending in "-" is glued directly to the next line, with the
           hyphen removed ("ap-" + "pointed" becomes "appointed"). Note that
           this also drops the hyphen of compounds broken at a line end.
        2. Other lines in the same paragraph are joined with a space.
        3. A new paragraph starts after a blank line, or when the previous line
           ends a sentence (. " ' ? !) and the next starts with an uppercase letter.

        Runs of whitespace are collapsed. Paragraphs are separated by blank lines.
        """
        if not raw_text:
            return ""
        paragraphs = []
        current = []
        for line in raw_text.split("\n"):
            line = line.rstrip()
            if not line:
                if current:
                    paragraphs.append(" ".join(current))
                    current = []
                continue
            if current:
                last = current[-1]
                if last.endswith("-"):
                    # Merge into the previous line. A separate list item would
                    # be joined with a space ("ap pointed").
                    current[-1] = last[:-1] + line.lstrip()
                    continue
                ends_sentence = last.rstrip().endswith((".", '"', "'", "?", "!"))
                starts_upper = line.lstrip()[:1].isupper()
                if ends_sentence and starts_upper:
                    paragraphs.append(" ".join(current))
                    current = [line]
                    continue
            current.append(line)
        if current:
            paragraphs.append(" ".join(current))

        cleaned = (re.sub(r"\s{2,}", " ", p).strip() for p in paragraphs)
        return "\n\n".join(p for p in cleaned if p)

    # -----------------------------------------------------------------
    # Assembly
    # -----------------------------------------------------------------
    @staticmethod
    def _fuzzy_match(a, b, threshold=0.8):
        """Return True if strings ``a`` and ``b`` look like the same text.

        - Empty input never matches.
        - A substring of the other string always matches.
        - Strings whose lengths differ by more than 2x never match.
        - Otherwise, the SequenceMatcher ratio of the first 200 characters
          must exceed ``threshold``.
        """
        if not a or not b:
            return False
        shorter, longer = (a, b) if len(a) < len(b) else (b, a)
        if shorter in longer:
            return True
        if len(shorter) / len(longer) < 0.5:
            return False
        return SequenceMatcher(None, a[:200], b[:200]).ratio() > threshold

    @staticmethod
    def _region_index(idx, n_regions):
        """Return ``idx`` if it is a valid index into ``n_regions`` regions, else None.

        LLM output is untrusted. None, bools, non-integers and negative values
        are rejected; negative values would otherwise index from the end.
        """
        if isinstance(idx, bool) or not isinstance(idx, int):
            return None
        return idx if 0 <= idx < n_regions else None

    @classmethod
    def _classify_body_regions(cls, body_idxs, regions):
        """Sort an article's body region indices by role.

        Returns a dict of index lists under the keys "text", "subheadline",
        "byline", "caption" and "continuation". Invalid indices and regions
        without text are dropped.

        Roles are checked in priority order:
          1. headline class (subheadline)
          2. a "CONTINUED ON" marker (continuation)
          3. 2-4 capitalised words on the first line (byline)
          4. caption class (caption)
          5. anything else (plain body text)
        """
        roles = {"text": [], "subheadline": [], "byline": [], "caption": [], "continuation": []}
        for raw_idx in body_idxs:
            idx = cls._region_index(raw_idx, len(regions))
            if idx is None:
                continue
            region = regions[idx]
            text = region.get("text", "").strip()
            if not text:
                continue
            cls_name = region["class"].lower()
            if cls_name in HEADLINE_CLASSES:
                role = "subheadline"
            elif cls._CONTINUED_ON_RE.search(text):
                role = "continuation"
            elif cls._BYLINE_RE.match(text.split("\n")[0]):
                role = "byline"
            elif cls_name == "caption":
                role = "caption"
            else:
                role = "text"
            roles[role].append(idx)
        return roles

    @classmethod
    def _find_continuation(cls, continuation_idxs, regions, body):
        """Return ``(is_continued, continued_on_page, body)``.

        "PAGE n" is looked for first in the continuation regions. If not found
        there, a "CONTINUED ON ... PAGE n" marker is searched for in the body,
        and the body is truncated before that marker.
        """
        for idx in continuation_idxs:
            m = cls._PAGE_NUM_RE.search(regions[idx]["text"])
            if m:
                return True, int(m.group(1)), body
        m = cls._CONTINUED_BODY_RE.search(body)
        if m:
            return True, int(m.group(1)), body[: m.start()].strip()
        return False, None, body

    @classmethod
    def _strip_headline(cls, body, headline):
        """Drop body lines that fuzzily match ``headline`` (case/whitespace-insensitive, 0.85)."""
        if not headline:
            return body
        h_norm = " ".join(headline.split()).lower()
        kept = [
            ln for ln in body.split("\n")
            if not cls._fuzzy_match(" ".join(ln.split()).lower(), h_norm, 0.85)
        ]
        return "\n".join(kept).strip()

    @classmethod
    def _dedupe_paragraphs(cls, body):
        """Drop near-duplicate paragraphs; the longer copy keeps the earlier position.

        Splits on blank lines, or on single newlines when the body has no
        blank lines, and re-joins with blank lines.
        """
        paras = body.split("\n\n") if "\n\n" in body else body.split("\n")
        kept = []
        for para in paras:
            para = para.strip()
            if not para:
                continue
            norm = " ".join(para.split())
            for pos, existing in enumerate(kept):
                if cls._fuzzy_match(norm, " ".join(existing.split())):
                    if len(para) > len(existing):
                        # Replace in place so paragraph order is preserved.
                        kept[pos] = para
                    break
            else:
                kept.append(para)
        return "\n\n".join(kept)

    def _assemble_articles(self, grouping, regions, pdf_path, page_num,
                           image, is_digital):
        """Build article dicts from the LLM grouping.

        ``image`` is accepted for interface compatibility but is not used here,
        because region texts were already filled in by ``_extract_region_texts``.

        Returns:
            ``{"articles": [...]}``. Each article has the keys headline,
            subheadline, byline, dateline, body, caption, category,
            is_continued, continued_on_page and source_regions. Groups that
            yield neither a headline nor a body are dropped.
        """
        articles = []
        n_regions = len(regions)
        for group in grouping.get("articles") or []:
            if not isinstance(group, dict):
                continue
            raw_headline_idx = group.get("headline_region")
            body_idxs = group.get("body_regions") or []
            if not isinstance(body_idxs, (list, tuple)):
                body_idxs = []
            category = group.get("category")

            headline_idx = self._region_index(raw_headline_idx, n_regions)
            headline = ""
            headline_bbox = None
            if headline_idx is not None:
                headline = " ".join(regions[headline_idx].get("text", "").split())
                headline_bbox = regions[headline_idx]["bbox"]

            roles = self._classify_body_regions(body_idxs, regions)

            # Digital pages: re-read the union box of the text regions from the
            # PDF text layer. Scanned pages: join the per-region OCR texts.
            if is_digital and roles["text"]:
                raw_body = self._extract_article_body_bbox(
                    pdf_path, page_num, roles["text"], regions,
                    headline_bbox=headline_bbox,
                )
                body = self._format_body_text(raw_body)
            else:
                body = "\n\n".join(regions[i]["text"].strip() for i in roles["text"])

            subheadline = None
            if roles["subheadline"]:
                subheadline = " | ".join(
                    " ".join(regions[i]["text"].split()) for i in roles["subheadline"]
                )

            # First line of the first byline region is the byline; its last
            # line (if there is more than one) is the dateline.
            byline = dateline = None
            if roles["byline"]:
                bl_lines = re.split(r"\n+", regions[roles["byline"][0]]["text"].strip())
                byline = bl_lines[0].strip()
                if len(bl_lines) > 1:
                    dateline = bl_lines[-1].strip()

            caption = None
            if roles["caption"]:
                caption = " | ".join(regions[i]["text"].strip() for i in roles["caption"])

            is_continued, continued_on, body = self._find_continuation(
                roles["continuation"], regions, body,
            )

            # Cleanup: remove the headline and metadata text that the
            # bounding-box read may have pulled into the body.
            body = self._strip_headline(body, headline)
            for role in ("subheadline", "byline", "caption", "continuation"):
                for idx in roles[role]:
                    body = body.replace(regions[idx]["text"].strip(), "").strip()
            body = self._dedupe_paragraphs(body)
            body = re.sub(r"\n{3,}", "\n\n", body).strip()

            if not headline and not body:
                continue

            source_regions = [] if raw_headline_idx is None else [raw_headline_idx]
            source_regions.extend(body_idxs)
            articles.append({
                "headline": headline,
                "subheadline": subheadline,
                "byline": byline,
                "dateline": dateline,
                "body": body,
                "caption": caption,
                "category": category,
                "is_continued": is_continued,
                "continued_on_page": continued_on,
                "source_regions": source_regions,
            })
        logger.info(f"Assembled {len(articles)} articles")
        return {"articles": articles}