Spaces:
Sleeping
Sleeping
# =============================================================================
# 📰 Newspaper Article Extractor — Core Pipeline
# No UI dependencies. Can be used standalone:
#     from extractor import ExtractionPipeline
#     pipeline = ExtractionPipeline(api_key="...")
#     result, viz_image, regions, is_digital, total_pages = pipeline.extract(
#         pdf_path, page_num=0)
# =============================================================================
| import json | |
| import time | |
| import re | |
| import base64 | |
| import fitz | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| from ultralytics import YOLO | |
| from huggingface_hub import snapshot_download | |
| from openai import OpenAI | |
| from difflib import SequenceMatcher | |
| import io | |
| import os | |
| import logging | |
| from config import ( | |
| DPI, SCALE_FACTOR, LLM_BASE_URL, VISION_MODEL, | |
| YOLO_REPO, YOLO_CONF_THRESHOLD, | |
| SKIP_CLASSES, HEADLINE_CLASSES, BODY_CLASSES, | |
| BBOX_PADDING, HEADLINE_DISTANCE_FACTOR, HEADLINE_MIN_DISTANCE, | |
| GROUPING_PROMPT, | |
| ) | |
logger = logging.getLogger("newspaper_extractor")

# Shared EasyOCR reader. Stays None until a scanned PDF actually needs OCR,
# so digital-only workloads never pay the model-loading cost.
_ocr_reader = None


def _get_ocr_reader():
    """Return the module-wide EasyOCR reader, creating it on first use.

    The ``easyocr`` import is deferred to the first call so that importing
    this module does not require (or load) EasyOCR.
    """
    global _ocr_reader
    if _ocr_reader is not None:
        return _ocr_reader

    import easyocr

    logger.info("Loading EasyOCR fallback...")
    _ocr_reader = easyocr.Reader(['en'], gpu=False)
    return _ocr_reader
| # ============================================================================= | |
| # Pipeline class | |
| # ============================================================================= | |
class ExtractionPipeline:
    """Main extraction pipeline. Initialize once, call extract() per page.

    Per page the pipeline renders the PDF page to an image, detects layout
    regions with a YOLO model, extracts text for each region (PyMuPDF for
    digital PDFs, EasyOCR otherwise), asks a vision LLM to group the regions
    into articles and finally assembles one dict per article.
    """

    def __init__(self, api_key, cache_dir="/tmp/hf_cache"):
        """Load the YOLO weights and create the LLM client.

        Args:
            api_key: API key handed to the OpenAI-compatible client.
            cache_dir: Cache directory passed to ``snapshot_download``.
        """
        # Load YOLO model
        logger.info("Loading YOLO model...")
        repo_path = snapshot_download(YOLO_REPO, cache_dir=cache_dir)
        self.yolo_model = YOLO(os.path.join(repo_path, "weights/best.pt"))
        logger.info(f"YOLO classes: {self.yolo_model.names}")
        # LLM client
        self.llm_client = OpenAI(base_url=LLM_BASE_URL, api_key=api_key)
        logger.info("✅ Pipeline initialized")

    # -----------------------------------------------------------------
    # Public API
    # -----------------------------------------------------------------
    def extract(self, pdf_path, page_num=0):
        """Extract articles from a single PDF page.

        Args:
            pdf_path: Path of the PDF file.
            page_num: Zero-based page index.

        Returns:
            A 5-tuple ``(result, viz_image, regions, is_digital, total_pages)``.
            ``result`` is ``{"articles": [...]}``, ``viz_image`` is the page
            image with the detected regions drawn on it and ``regions`` is the
            list of region dicts. When ``page_num`` is out of range the first
            three items are ``None``.
        """
        is_digital = self._is_digital_pdf(pdf_path, page_num)
        image, total_pages = self._pdf_page_to_image(pdf_path, page_num)
        if image is None:
            return None, None, None, is_digital, total_pages

        regions = self._detect_layout(image)
        viz_image = self._visualize_layout(image, regions)
        img_b64 = self._create_numbered_image(image, regions)
        # Region texts must exist before grouping: the LLM prompt shows a
        # short preview of every region.
        self._extract_region_texts(regions, pdf_path, page_num, image, is_digital)
        grouping = self._group_regions(img_b64, regions)
        result = self._assemble_articles(
            grouping, regions, pdf_path, page_num, image, is_digital,
        )
        return result, viz_image, regions, is_digital, total_pages

    def get_page_count(self, pdf_path):
        """Return total page count of a PDF."""
        with fitz.open(pdf_path) as doc:
            return doc.page_count

    # -----------------------------------------------------------------
    # PDF helpers
    # -----------------------------------------------------------------
    def _is_digital_pdf(self, pdf_path, page_num=0):
        """Return True when the page carries more than 500 chars of text.

        An out-of-range ``page_num`` yields ``False`` instead of raising, so
        that ``extract`` can report the invalid page through its normal
        ``None`` return path.
        """
        with fitz.open(pdf_path) as doc:
            if not 0 <= page_num < doc.page_count:
                return False
            text = doc[page_num].get_text("text").strip()
        is_digital = len(text) > 500
        logger.info(f"PDF type: {'Digital' if is_digital else 'Scanned'} ({len(text)} chars)")
        return is_digital

    def _pdf_page_to_image(self, pdf_path, page_num):
        """Render one page to a PIL image.

        Returns:
            ``(image, total_pages)``; ``image`` is ``None`` when ``page_num``
            is outside ``[0, total_pages)``.
        """
        with fitz.open(pdf_path) as doc:
            total = doc.page_count
            if page_num < 0 or page_num >= total:
                return None, total
            mat = fitz.Matrix(SCALE_FACTOR, SCALE_FACTOR)
            pix = doc[page_num].get_pixmap(matrix=mat)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        return img, total

    @staticmethod
    def _pixel_bbox_to_pdf_rect(bbox):
        """Convert an ``(x1, y1, x2, y2)`` pixel box of the rendered image
        into a ``fitz.Rect`` in PDF coordinates by undoing ``SCALE_FACTOR``."""
        x1, y1, x2, y2 = bbox
        return fitz.Rect(
            x1 / SCALE_FACTOR, y1 / SCALE_FACTOR,
            x2 / SCALE_FACTOR, y2 / SCALE_FACTOR,
        )

    # -----------------------------------------------------------------
    # Layout detection
    # -----------------------------------------------------------------
    def _detect_layout(self, image):
        """Run YOLO on the page image and return the detected regions.

        Returns:
            List of dicts with keys ``bbox`` (``[x1, y1, x2, y2]`` ints),
            ``class`` and ``confidence``, sorted top-to-bottom and then
            left-to-right by the box's top-left corner.
        """
        results = self.yolo_model.predict(
            source=image, conf=YOLO_CONF_THRESHOLD, imgsz=1024, verbose=False,
        )
        regions = []
        for result in results:
            boxes = result.boxes
            for xyxy, conf, cls_id in zip(
                boxes.xyxy.tolist(), boxes.conf.tolist(), boxes.cls.tolist()
            ):
                regions.append({
                    "bbox": [int(v) for v in xyxy],
                    "class": result.names[int(cls_id)],
                    "confidence": round(conf, 3),
                })
        regions.sort(key=lambda r: (r["bbox"][1], r["bbox"][0]))
        logger.info(f"Detected {len(regions)} regions")
        return regions

    # -----------------------------------------------------------------
    # Visualization
    # -----------------------------------------------------------------
    @staticmethod
    def _visualize_layout(image, regions):
        """Return a copy of ``image`` with every region outlined and labelled
        ``[index] class``, colour-coded by region class."""
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        colors = {
            "title": "#E24B4A", "text": "#378ADD", "picture": "#639922",
            "figure": "#639922", "table": "#BA7517", "caption": "#1D9E75",
            "section-header": "#E24B4A", "header": "#888780", "footer": "#888780",
        }
        for i, r in enumerate(regions):
            x1, y1, x2, y2 = r["bbox"]
            color = colors.get(r["class"].lower(), "#888780")
            draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
            label = f"[{i}] {r['class']}"
            # Label background sized from the label length (7 px per char).
            draw.rectangle([x1, y1, x1 + len(label) * 7, y1 + 16], fill=color)
            draw.text((x1 + 2, y1 + 1), label, fill="white")
        return img_copy

    @staticmethod
    def _create_numbered_image(image, regions):
        """Return a base64 JPEG of the page with each region boxed and tagged
        with its index, downscaled to fit 1000x1000 for the vision LLM."""
        img_copy = image.copy()
        draw = ImageDraw.Draw(img_copy)
        try:
            font = ImageFont.truetype(
                "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28,
            )
        except OSError:
            # Font file not available on this system: use PIL's built-in font.
            font = ImageFont.load_default()
        for i, r in enumerate(regions):
            x1, y1, x2, y2 = r["bbox"]
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
            draw.rectangle([x1, y1, x1 + 32, y1 + 32], fill="red")
            draw.text((x1 + 6, y1 + 2), str(i), fill="white", font=font)
        img_copy.thumbnail((1000, 1000))
        buf = io.BytesIO()
        img_copy.save(buf, format="JPEG", quality=55)
        return base64.b64encode(buf.getvalue()).decode()

    # -----------------------------------------------------------------
    # Text extraction
    # -----------------------------------------------------------------
    def _extract_region_text_pymupdf(self, pdf_path, page_num, bbox):
        """Return the embedded PDF text inside the pixel ``bbox``."""
        with fitz.open(pdf_path) as doc:
            clip = self._pixel_bbox_to_pdf_rect(bbox)
            return doc[page_num].get_text("text", clip=clip).strip()

    @staticmethod
    def _extract_region_text_ocr(image, bbox, region_class):
        """OCR the ``bbox`` crop of ``image`` (padded by 5 px, clamped to the
        image) and return the recognised text joined with spaces.

        Headline classes use a text threshold of 0.5, everything else 0.4.
        """
        reader = _get_ocr_reader()
        x1, y1, x2, y2 = bbox
        pad = 5
        x1, y1 = max(0, x1 - pad), max(0, y1 - pad)
        x2, y2 = min(image.width, x2 + pad), min(image.height, y2 + pad)
        crop_np = np.array(image.crop((x1, y1, x2, y2)))
        threshold = 0.5 if region_class.lower() in HEADLINE_CLASSES else 0.4
        results = reader.readtext(crop_np, paragraph=True, text_threshold=threshold)
        return " ".join(r[1] for r in results).strip()

    def _extract_region_texts(self, regions, pdf_path, page_num, image, is_digital):
        """Extract text for all regions (used for LLM summary).

        Mutates each region dict in place, setting ``text`` and
        ``text_source`` (``"skipped"``, ``"pymupdf"`` or ``"ocr"``).
        """
        for r in regions:
            if r["class"].lower() in SKIP_CLASSES:
                r["text"] = ""
                r["text_source"] = "skipped"
                continue
            if is_digital:
                text = self._extract_region_text_pymupdf(pdf_path, page_num, r["bbox"])
                if len(text) > 3:
                    r["text"] = text
                    r["text_source"] = "pymupdf"
                    continue
            # Scanned page, or a digital region with (almost) no embedded text.
            r["text"] = self._extract_region_text_ocr(image, r["bbox"], r["class"])
            r["text_source"] = "ocr"

    def _extract_article_body_bbox(self, pdf_path, page_num, body_idxs, regions,
                                   headline_bbox=None):
        """Extract body text from bounding box in PDF content stream order.

        The union box of the body regions is shrunk by ``abs(BBOX_PADDING)``
        on every side before it is used as the PyMuPDF clip rectangle. When
        ``headline_bbox`` is given, only regions whose horizontal centre lies
        within ``max(headline width * HEADLINE_DISTANCE_FACTOR,
        HEADLINE_MIN_DISTANCE)`` of the headline centre are used; if that
        filter would drop every region, it is ignored.

        Returns:
            The stripped text inside the box, or ``""`` when ``body_idxs`` is
            empty.
        """
        if not body_idxs:
            return ""
        valid = body_idxs
        # Headline-based horizontal constraint
        if headline_bbox:
            h_x1, _, h_x2, _ = headline_bbox
            h_center = (h_x1 + h_x2) / 2
            max_dist = max(
                (h_x2 - h_x1) * HEADLINE_DISTANCE_FACTOR, HEADLINE_MIN_DISTANCE,
            )
            filtered = [
                i for i in valid
                if abs((regions[i]["bbox"][0] + regions[i]["bbox"][2]) / 2 - h_center)
                <= max_dist
            ]
            if filtered:
                valid = filtered
        pad = abs(BBOX_PADDING)
        boxes = [regions[i]["bbox"] for i in valid]
        x1 = min(b[0] for b in boxes) + pad
        y1 = min(b[1] for b in boxes) + pad
        x2 = max(b[2] for b in boxes) - pad
        y2 = max(b[3] for b in boxes) - pad
        pdf_rect = self._pixel_bbox_to_pdf_rect((x1, y1, x2, y2))
        with fitz.open(pdf_path) as doc:
            return doc[page_num].get_text("text", clip=pdf_rect).strip()

    # -----------------------------------------------------------------
    # Vision LLM
    # -----------------------------------------------------------------
    def _call_vision_llm(self, img_b64, prompt, max_retries=3):
        """Send the page image plus ``prompt`` to the vision model.

        Rate-limit errors (message contains ``429`` or ``rate``) are retried
        up to ``max_retries`` times, waiting the delay quoted in the error
        message plus 2 s, or 60 s when none is quoted. Any other error
        propagates immediately.

        Returns:
            The message content of the first choice.

        Raises:
            RuntimeError: every attempt was rate limited.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                resp = self.llm_client.chat.completions.create(
                    model=VISION_MODEL,
                    messages=[{
                        "role": "user",
                        "content": [
                            {"type": "image_url",
                             "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}},
                            {"type": "text", "text": prompt},
                        ],
                    }],
                    temperature=0.1,
                    max_tokens=2048,
                )
                return resp.choices[0].message.content
            except Exception as e:
                message = str(e)
                if "429" not in message and "rate" not in message.lower():
                    raise
                last_error = e
                if attempt == max_retries - 1:
                    # No attempt follows, so sleeping would only delay the error.
                    break
                wait = 60
                m = re.search(r"(\d+\.?\d*)\s*s", message)
                if m:
                    wait = float(m.group(1)) + 2
                logger.warning(f"Rate limited, waiting {wait:.0f}s (attempt {attempt + 1})")
                time.sleep(wait)
        raise RuntimeError("Vision LLM failed after retries") from last_error

    @staticmethod
    def _parse_grouping_json(raw):
        """Parse the LLM reply into a Python object.

        Tolerates a Markdown code fence (with or without a language tag, on
        one line or several) and text before or after the JSON object.

        Raises:
            json.JSONDecodeError: the reply contains no valid JSON.
        """
        text = (raw or "").strip()
        if text.startswith("```"):
            # Drop the opening fence with its optional language tag, then
            # everything from the closing fence onwards.
            text = re.sub(r"^```[A-Za-z0-9_-]*\s*", "", text)
            text = text.rsplit("```", 1)[0]
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        return json.loads(text)

    def _group_regions(self, img_b64, regions):
        """Ask the vision LLM to group regions into articles.

        Returns:
            The parsed grouping dict. This method reads its ``articles`` list
            (items with ``headline_region`` and ``body_regions``) and its
            ``discarded_regions`` list.
        """
        lines = []
        for i, r in enumerate(regions):
            text = r.get("text", "")
            if not text:
                lines.append(f"[{i}] {r['class']} — (no text / image)")
            else:
                preview = text[:80].replace("\n", " ")
                lines.append(f"[{i}] {r['class']} — \"{preview}\"")
        prompt = GROUPING_PROMPT.format(region_summary="\n".join(lines))
        grouping = self._parse_grouping_json(self._call_vision_llm(img_b64, prompt))
        articles = grouping.get("articles") or []
        # Log orphans (informational — bounding box compensates)
        assigned = set()
        for art in articles:
            if art.get("headline_region") is not None:
                assigned.add(art["headline_region"])
            assigned.update(art.get("body_regions") or [])
        assigned.update(grouping.get("discarded_regions") or [])
        orphaned_text = [
            i for i in set(range(len(regions))) - assigned
            if regions[i].get("text")
        ]
        if orphaned_text:
            logger.info(
                f"{len(orphaned_text)} unassigned text regions "
                "(bounding box will capture them)"
            )
        logger.info(f"Grouped into {len(articles)} articles")
        return grouping

    # -----------------------------------------------------------------
    # Text formatting
    # -----------------------------------------------------------------
    @staticmethod
    def _format_body_text(raw_text):
        """
        Convert PyMuPDF raw output into clean paragraphed text.

        PyMuPDF returns text with line breaks at every visual line end
        in the PDF column. This function:
        1. Rejoins hyphenated words split across lines
        2. Joins lines within the same paragraph
        3. Detects paragraph breaks (sentence end + next line starts uppercase)

        Returns the paragraphs joined by blank lines, or "" for empty input.
        """
        if not raw_text:
            return ""
        paragraphs = []
        current = []
        for line in raw_text.split("\n"):
            line = line.rstrip()
            # An empty line closes the paragraph being built.
            if not line:
                if current:
                    paragraphs.append(" ".join(current))
                    current = []
                continue
            if current:
                last = current[-1]
                # Rejoin hyphenated word: "ap-\npointed" → "appointed".
                # The two fragments are merged into one entry so that the
                # later " ".join does not put a space inside the word.
                if last.endswith("-"):
                    current[-1] = last[:-1] + line.lstrip()
                    continue
                # Paragraph break: previous line ends sentence + new line starts uppercase
                ends_sentence = last.rstrip().endswith((".", '"', "'", "?", "!"))
                starts_upper = line.lstrip()[:1].isupper()
                if ends_sentence and starts_upper:
                    paragraphs.append(" ".join(current))
                    current = [line]
                    continue
            current.append(line)
        # Flush remaining
        if current:
            paragraphs.append(" ".join(current))
        # Collapse whitespace runs and drop paragraphs that end up empty.
        cleaned = [re.sub(r"\s{2,}", " ", para).strip() for para in paragraphs]
        return "\n\n".join(para for para in cleaned if para)

    # -----------------------------------------------------------------
    # Assembly
    # -----------------------------------------------------------------
    @staticmethod
    def _fuzzy_match(a, b, threshold=0.8):
        """Return True when ``a`` and ``b`` are near-duplicates.

        True if the shorter string is contained in the longer one; False if
        either is empty or the shorter is less than half the longer's length;
        otherwise the ``SequenceMatcher`` ratio of the first 200 characters
        must exceed ``threshold``.
        """
        if not a or not b:
            return False
        shorter, longer = (a, b) if len(a) < len(b) else (b, a)
        if shorter in longer:
            return True
        if len(shorter) / len(longer) < 0.5:
            return False
        return SequenceMatcher(None, a[:200], b[:200]).ratio() > threshold

    @staticmethod
    def _is_valid_region_index(idx, count):
        """Return True when ``idx`` is an int within ``[0, count)``.

        Indices come from the LLM, so negative values (which would silently
        wrap around) and non-integers are rejected.
        """
        return isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < count

    @classmethod
    def _classify_body_regions(cls, body_idxs, regions):
        """Split body region indices by role; invalid or text-less ones are dropped.

        Returns a dict with the keys ``body``, ``subheadline``, ``byline``,
        ``caption`` and ``continuation``, each mapping to a list of indices.
        """
        roles = {
            "body": [], "subheadline": [], "byline": [],
            "caption": [], "continuation": [],
        }
        for idx in body_idxs:
            if not cls._is_valid_region_index(idx, len(regions)):
                continue
            r = regions[idx]
            text = r.get("text", "").strip()
            if not text:
                continue
            region_class = r["class"].lower()
            if region_class in HEADLINE_CLASSES:
                roles["subheadline"].append(idx)
            elif re.search(r"CONTINUED\s+ON", text, re.IGNORECASE):
                roles["continuation"].append(idx)
            elif re.match(
                # First line is 2-4 capitalised words: treated as an author name.
                r"^[A-Z][a-z]+(?:\s+[A-Z][a-z]+){1,3}\s*$",
                text.split("\n")[0],
            ):
                roles["byline"].append(idx)
            elif region_class == "caption":
                roles["caption"].append(idx)
            else:
                roles["body"].append(idx)
        return roles

    @classmethod
    def _strip_headline_lines(cls, body, headline):
        """Remove body lines that fuzzy-match the headline (threshold 0.85).

        NOTE(review): the fuzzy match treats containment as a match, so a
        body line that contains the whole headline is dropped entirely —
        confirm this is intended.
        """
        if not headline:
            return body
        h_norm = " ".join(headline.split()).lower()
        kept = [
            ln for ln in body.split("\n")
            if not cls._fuzzy_match(" ".join(ln.split()).lower(), h_norm, 0.85)
        ]
        return "\n".join(kept).strip()

    @classmethod
    def _dedupe_paragraphs(cls, body):
        """Drop near-duplicate paragraphs, keeping the longer of each pair.

        Paragraphs are split on blank lines when the body has any, otherwise
        on single newlines. A longer duplicate replaces the earlier one and
        is placed at the end of the result.
        """
        paras = body.split("\n\n") if "\n\n" in body else body.split("\n")
        deduped = []
        for p in paras:
            p = p.strip()
            if not p:
                continue
            p_norm = " ".join(p.split())
            match = next(
                (e for e in deduped if cls._fuzzy_match(p_norm, " ".join(e.split()))),
                None,
            )
            if match is None:
                deduped.append(p)
            elif len(p) > len(match):
                deduped.remove(match)
                deduped.append(p)
        return re.sub(r"\n{3,}", "\n\n", "\n\n".join(deduped)).strip()

    def _assemble_articles(self, grouping, regions, pdf_path, page_num,
                           image, is_digital):
        """Turn the LLM grouping into article dicts.

        Args:
            grouping: Parsed LLM reply; its ``articles`` list is read.
            regions: Region dicts that already carry ``text``.
            pdf_path, page_num: Used to re-read body text from digital PDFs.
            image: Unused here; kept for interface compatibility.
            is_digital: When True, body text is re-extracted from the PDF
                bounding box instead of being joined from region texts.

        Returns:
            ``{"articles": [...]}`` where each article has the keys
            ``headline``, ``subheadline``, ``byline``, ``dateline``, ``body``,
            ``caption``, ``category``, ``is_continued``,
            ``continued_on_page`` and ``source_regions``. Groups with neither
            headline nor body are skipped.
        """
        articles = []
        for group in grouping.get("articles") or []:
            headline_idx = group.get("headline_region")
            body_idxs = group.get("body_regions") or []
            category = group.get("category")

            # Headline
            headline = ""
            headline_bbox = None
            if self._is_valid_region_index(headline_idx, len(regions)):
                headline = " ".join(regions[headline_idx].get("text", "").split())
                headline_bbox = regions[headline_idx]["bbox"]

            # Classify body regions by role
            roles = self._classify_body_regions(body_idxs, regions)
            text_body_idxs = roles["body"]
            subheadline_idxs = roles["subheadline"]
            byline_idxs = roles["byline"]
            caption_idxs = roles["caption"]
            continuation_idxs = roles["continuation"]

            # --- Body text ---
            if is_digital and text_body_idxs:
                raw_body = self._extract_article_body_bbox(
                    pdf_path, page_num, text_body_idxs, regions,
                    headline_bbox=headline_bbox,
                )
                body = self._format_body_text(raw_body)
            else:
                body = "\n\n".join(regions[i]["text"].strip() for i in text_body_idxs)

            # --- Metadata from individual regions ---
            subheadline = None
            if subheadline_idxs:
                subheadline = " | ".join(
                    " ".join(regions[i]["text"].split()) for i in subheadline_idxs
                )
            byline = None
            dateline = None
            if byline_idxs:
                # First line is the byline; a last line, if distinct, the dateline.
                bl_parts = re.split(r"\n+", regions[byline_idxs[0]]["text"].strip())
                byline = bl_parts[0].strip() if bl_parts else None
                if len(bl_parts) > 1:
                    dateline = bl_parts[-1].strip()
            caption = None
            if caption_idxs:
                caption = " | ".join(regions[i]["text"].strip() for i in caption_idxs)

            # --- Continuation ---
            is_continued = False
            continued_on = None
            for idx in continuation_idxs:
                m = re.search(r"PAGE\s+(\d+)", regions[idx]["text"], re.IGNORECASE)
                if m:
                    is_continued = True
                    continued_on = int(m.group(1))
                    break
            if not is_continued:
                m = re.search(r"CONTINUED\s+ON\s+.*?PAGE\s+(\d+)", body, re.IGNORECASE)
                if m:
                    is_continued = True
                    continued_on = int(m.group(1))
                    # Everything from the continuation marker on is not body text.
                    body = body[: m.start()].strip()

            # --- Cleanup ---
            # Remove headline from body
            body = self._strip_headline_lines(body, headline)
            # Remove metadata text from body
            for idx_list in (subheadline_idxs, byline_idxs, caption_idxs, continuation_idxs):
                for idx in idx_list:
                    body = body.replace(regions[idx]["text"].strip(), "").strip()
            # Deduplicate paragraphs
            body = self._dedupe_paragraphs(body)

            if not headline and not body:
                continue

            # Report the indices exactly as the LLM returned them.
            source_regions = []
            if headline_idx is not None:
                source_regions.append(headline_idx)
            source_regions.extend(body_idxs)
            articles.append({
                "headline": headline,
                "subheadline": subheadline,
                "byline": byline,
                "dateline": dateline,
                "body": body,
                "caption": caption,
                "category": category,
                "is_continued": is_continued,
                "continued_on_page": continued_on,
                "source_regions": source_regions,
            })
        logger.info(f"Assembled {len(articles)} articles")
        return {"articles": articles}