# PDF-Parser / app.py
# Origin: Hugging Face Space (user saifisvibin, commit 62fe271).
# The web-viewer header lines were converted to comments so the module parses.
import json
import os
import shutil
import threading
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
from flask import Flask, render_template, request, jsonify, send_file, send_from_directory
from flask_cors import CORS
from loguru import logger
from werkzeug.utils import secure_filename

import main as extractor
# Flask application and configuration.
app = Flask(__name__)

# Enable CORS for all /api/* routes so browser clients hosted on other
# origins (e.g. the Hugging Face Spaces iframe) can call the JSON API.
CORS(app, resources={r"/api/*": {"origins": "*"}})

app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 # 500MB max file size
app.config['UPLOAD_FOLDER'] = './uploads'  # transient storage for incoming PDFs
app.config['OUTPUT_FOLDER'] = './output'   # persistent extraction results served via /output

# Ensure directories exist before the first request arrives.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)

# Global model instance, loaded lazily by load_model_once().
_model = None

# Progress tracking shared between request handlers and worker threads:
# {task_id: {'status': 'processing'|'completed'|'error', 'progress': 0-100,
#            'message': str, 'results': [], 'file_progress': {filename: progress}}}
_progress_tracker: Dict[str, Dict] = {}
# RLock because helpers that take the lock may run while it is already held.
_progress_lock = threading.RLock() # Use RLock for reentrant locking
def get_device_info() -> Dict[str, Any]:
    """Report which compute device torch will use.

    Returns:
        Dict with keys:
            device: "cuda" if a GPU is visible, otherwise "cpu".
            cuda_available: result of torch.cuda.is_available().
            device_name: name of GPU 0, or None on CPU-only hosts.
            device_count: number of visible CUDA devices (0 on CPU).
    """
    cuda_available = torch.cuda.is_available()
    # Fix: the annotation previously used the builtin `any` (a function),
    # not `typing.Any`; also annotate the dict for clarity.
    info: Dict[str, Any] = {
        "device": "cuda" if cuda_available else "cpu",
        "cuda_available": cuda_available,
        "device_name": None,
        "device_count": 0,
    }
    if cuda_available:
        info["device_name"] = torch.cuda.get_device_name(0)
        info["device_count"] = torch.cuda.device_count()
    return info
def load_model_once():
    """Return the cached DocLayout-YOLO model, loading it on first use."""
    global _model
    if _model is not None:
        return _model
    logger.info("Loading DocLayout-YOLO model...")
    _model = extractor.get_model()
    logger.info("Model loaded successfully")
    return _model
@app.route('/')
def index():
    """Render the main page, passing current GPU/CPU availability."""
    return render_template('index.html', device_info=get_device_info())
@app.route('/api/docs')
def api_docs():
    """API documentation page showing all available endpoints."""
    routes = []
    for rule in app.url_map.iter_rules():
        # Only document the public API and file-serving routes.
        if not (rule.rule.startswith('/api') or rule.rule.startswith('/output')):
            continue
        view = app.view_functions.get(rule.endpoint)
        doc = view.__doc__ if view and hasattr(view, '__doc__') else 'No description'
        routes.append({
            'endpoint': rule.rule,
            'methods': ','.join(sorted(rule.methods - {'OPTIONS', 'HEAD'})),
            'description': doc.strip() if doc else 'No description',
        })
    # Hugging Face Spaces sits behind a TLS proxy; always advertise HTTPS URLs.
    base_url = request.host_url.rstrip('/')
    if base_url.startswith('http://'):
        base_url = base_url.replace('http://', 'https://')
    return render_template('api_docs.html', routes=routes, base_url=base_url)
@app.route('/api/predict', methods=['POST', 'GET'])
def predict():
    """
    Clean REST API endpoint for PDF extraction.
    Accepts a PDF file and returns extracted text, tables, and figures.

    Request:
        - Method: POST
        - Content-Type: multipart/form-data
        - Body: file (PDF file)

    Response:
        {
            "status": "success",
            "filename": "document.pdf",
            "text": "extracted markdown text...",
            "tables": [...],
            "figures": [...],
            "summary": {...}
        }
    """
    # Fix: this docstring previously appeared AFTER the GET handler below, so
    # it was a dead string expression and func.__doc__ was None (and /api/docs
    # showed "No description" for this endpoint).
    # Handle GET requests with an informational 405 + usage hint.
    if request.method == 'GET':
        return jsonify({
            'status': 'info',
            'message': 'This endpoint accepts POST requests only. Please use POST method with a PDF file in the "file" field.',
            'usage': {
                'method': 'POST',
                'content_type': 'multipart/form-data',
                'body': {
                    'file': 'PDF file to process'
                },
                'example_curl': 'curl -X POST https://saifisvibin-volaris-pdf-tool.hf.space/api/predict -F "file=@document.pdf"'
            }
        }), 405
    try:
        # Validate the upload: field present, non-empty name, .pdf extension.
        if 'file' not in request.files:
            return jsonify({
                'status': 'error',
                'error': 'No file provided. Please upload a PDF file using the "file" field.'
            }), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({
                'status': 'error',
                'error': 'No file selected'
            }), 400
        if not file.filename.lower().endswith('.pdf'):
            return jsonify({
                'status': 'error',
                'error': 'Invalid file type. Please upload a PDF file.'
            }), 400
        filename = secure_filename(file.filename)
        stem = Path(filename).stem
        # Create a permanent output directory for this request; stem + unix
        # timestamp keeps concurrent requests for the same file from colliding
        # (NOTE(review): same-second collisions are still possible).
        import time
        unique_id = f"{stem}_{int(time.time())}"
        output_dir = Path(app.config['OUTPUT_FOLDER']) / unique_id
        output_dir.mkdir(parents=True, exist_ok=True)
        # Temporary upload directory, removed in the finally block below.
        temp_upload = Path(app.config['UPLOAD_FOLDER']) / f"temp_{uuid.uuid4().hex}"
        temp_upload.mkdir(parents=True, exist_ok=True)
        try:
            # Save uploaded file
            pdf_path = temp_upload / filename
            file_data = file.read()
            pdf_path.write_bytes(file_data)
            # Load model if needed (cached after the first call)
            load_model_once()
            # Process PDF (extract both images and markdown)
            extractor.USE_MULTIPROCESSING = False
            extractor.process_pdf_with_pool(
                pdf_path,
                output_dir,
                pool=None,
                extract_images=True,
                extract_markdown=True,
            )
            # Response skeleton; filled in from the extractor's outputs below.
            result = {
                'status': 'success',
                'filename': filename,
                'text': '',
                'tables': [],
                'figures': [],
                'summary': {
                    'total_pages': 0,
                    'figures_count': 0,
                    'tables_count': 0,
                    'elements_count': 0
                }
            }
            # Extract markdown text
            markdown_path = output_dir / f"{stem}.md"
            if markdown_path.exists():
                result['text'] = markdown_path.read_text(encoding='utf-8')
            # Base URL for constructing full image URLs
            base_url = request.host_url.rstrip('/')
            if 'hf.space' in base_url:
                # Force HTTPS for Hugging Face Spaces
                base_url = base_url.replace('http://', 'https://')
            # Figures and tables come from the extractor's JSON sidecar
            json_path = output_dir / f"{stem}_content_list.json"
            if json_path.exists():
                elements = json.loads(json_path.read_text(encoding='utf-8'))
                figures = [e for e in elements if e.get('type') == 'figure']
                tables = [e for e in elements if e.get('type') == 'table']
                # Page count is best-effort; the response is valid without it.
                try:
                    import pypdfium2 as pdfium
                    pdf_bytes = pdf_path.read_bytes()
                    doc = pdfium.PdfDocument(pdf_bytes)
                    result['summary']['total_pages'] = len(doc)
                    doc.close()
                except Exception:
                    # Fix: was a bare `except:`; keep best-effort semantics but
                    # don't swallow SystemExit/KeyboardInterrupt.
                    pass
                # Format figures
                for fig in figures:
                    figure_data = {
                        'page': fig.get('page', 0),
                        'bbox': fig.get('bbox_pixels', []),
                        'confidence': fig.get('conf', 0.0),
                        'width': fig.get('width', 0),
                        'height': fig.get('height', 0),
                    }
                    # Include image URL if available
                    if fig.get('image_path'):
                        img_path = output_dir / fig['image_path']
                        if img_path.exists():
                            relative_path = str(img_path.relative_to(app.config['OUTPUT_FOLDER']))
                            figure_data['image_url'] = f"{base_url}/output/{relative_path}"
                            figure_data['image_path'] = relative_path
                    result['figures'].append(figure_data)
                # Format tables
                for tab in tables:
                    table_data = {
                        'page': tab.get('page', 0),
                        'bbox': tab.get('bbox_pixels', []),
                        'confidence': tab.get('conf', 0.0),
                        'width': tab.get('width', 0),
                        'height': tab.get('height', 0),
                    }
                    # Include image URL if available
                    if tab.get('image_path'):
                        img_path = output_dir / tab['image_path']
                        if img_path.exists():
                            relative_path = str(img_path.relative_to(app.config['OUTPUT_FOLDER']))
                            table_data['image_url'] = f"{base_url}/output/{relative_path}"
                            table_data['image_path'] = relative_path
                    result['tables'].append(table_data)
                result['summary']['figures_count'] = len(figures)
                result['summary']['tables_count'] = len(tables)
                result['summary']['elements_count'] = len(elements)
            return jsonify(result)
        finally:
            # Clean up temporary upload directory only (keep output_dir so
            # image URLs in the response remain fetchable).
            try:
                if temp_upload.exists():
                    if temp_upload.is_file():
                        temp_upload.unlink()
                    else:
                        shutil.rmtree(temp_upload, ignore_errors=True)
            except Exception as e:
                logger.warning(f"Error cleaning up temp upload files: {e}")
    except Exception as e:
        logger.error(f"Error in /api/predict: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return jsonify({
            'status': 'error',
            'error': str(e)
        }), 500
@app.route('/api/device-info')
def device_info():
    """Return GPU/CPU availability info as JSON (see get_device_info)."""
    info = get_device_info()
    return jsonify(info)
def _update_task_progress(task_id: str, filename: str, file_progress: int, message: str):
    """Record one file's progress and refresh the task's overall percentage.

    Overall progress is the mean of every per-file percentage seen so far.
    Unknown task ids are ignored silently (the task may have been dropped).
    """
    with _progress_lock:
        task = _progress_tracker.get(task_id)
        if task is None:
            return
        per_file = task.setdefault('file_progress', {})
        per_file[filename] = file_progress
        if per_file:
            task['progress'] = int(sum(per_file.values()) / len(per_file))
        task['message'] = message
def process_file_background(task_id: str, file_data: bytes, filename: str, extraction_mode: str):
    """Process a single PDF in the background and update task progress.

    Args:
        task_id: Key into the module-level _progress_tracker.
        file_data: Raw PDF bytes (read eagerly before the request closed).
        filename: Original upload name; sanitized before filesystem use.
        extraction_mode: 'images', 'markdown', or anything else for both.

    On success appends a result dict to the tracker's 'results'; on failure
    appends an entry with an 'error' key instead.

    Fix: progress/log messages previously contained the literal text
    "(unknown)" instead of interpolating the filename, making multi-file
    batch progress unreadable.
    """
    filename = secure_filename(filename)
    try:
        _update_task_progress(task_id, filename, 5, f'Processing {filename}...')
        stem = Path(filename).stem
        # 'markdown' mode skips layout images; 'images' mode skips markdown.
        include_images = extraction_mode != 'markdown'
        include_markdown = extraction_mode != 'images'
        # Ensure upload directory exists
        upload_dir = Path(app.config['UPLOAD_FOLDER'])
        upload_dir.mkdir(parents=True, exist_ok=True)
        # Persist the uploaded bytes to disk
        upload_path = upload_dir / filename
        upload_path.write_bytes(file_data)
        _update_task_progress(task_id, filename, 15, f'Saved {filename}, preparing output...')
        # Prepare output directory (keyed by stem, so re-uploads overwrite)
        output_dir = Path(app.config['OUTPUT_FOLDER']) / stem
        output_dir.mkdir(parents=True, exist_ok=True)
        # Move the PDF into the output directory; replace() overwrites.
        pdf_path = output_dir / filename
        upload_path.replace(pdf_path)
        _update_task_progress(task_id, filename, 25, f'Loading model and processing {filename}...')
        # Process PDF single-process: we're already in a worker thread.
        extractor.USE_MULTIPROCESSING = False
        logger.info(f"Processing {filename} (images={include_images}, markdown={include_markdown})")
        if include_images:
            try:
                load_model_once()
                logger.info(f"Model loaded successfully for {filename}")
            except Exception as model_error:
                logger.error(f"Failed to load model for {filename}: {model_error}")
                import traceback
                logger.error(traceback.format_exc())
                raise Exception(f"Model loading failed: {str(model_error)}. The processing service may be unavailable.")
        _update_task_progress(task_id, filename, 30, f'Extracting content from {filename}...')
        extractor.process_pdf_with_pool(
            pdf_path,
            output_dir,
            pool=None,
            extract_images=include_images,
            extract_markdown=include_markdown,
        )
        _update_task_progress(task_id, filename, 85, f'Collecting results for {filename}...')
        # Collect results emitted by the extractor
        json_path = output_dir / f"{stem}_content_list.json"
        elements = []
        if include_images and json_path.exists():
            elements = json.loads(json_path.read_text(encoding='utf-8'))
        annotated_pdf = None
        if include_images:
            candidate_pdf = output_dir / f"{stem}_layout.pdf"
            if candidate_pdf.exists():
                annotated_pdf = str(candidate_pdf.relative_to(app.config['OUTPUT_FOLDER']))
        markdown_path = None
        if include_markdown:
            candidate_md = output_dir / f"{stem}.md"
            if candidate_md.exists():
                markdown_path = str(candidate_md.relative_to(app.config['OUTPUT_FOLDER']))
        # Tally figures and tables
        figures = [e for e in elements if e.get('type') == 'figure']
        tables = [e for e in elements if e.get('type') == 'table']
        # NOTE: request.host_url is unavailable in a background thread, so
        # results carry relative paths; full URLs are built in /api/progress.
        result = {
            'filename': filename,
            'stem': stem,
            'output_dir': str(output_dir.relative_to(app.config['OUTPUT_FOLDER'])),
            'figures_count': len(figures),
            'tables_count': len(tables),
            'elements_count': len(elements),
            'annotated_pdf': annotated_pdf,
            'markdown_path': markdown_path,
            'include_images': include_images,
            'include_markdown': include_markdown,
        }
        with _progress_lock:
            # Mark this file 100% done and refresh the batch average
            if 'file_progress' not in _progress_tracker[task_id]:
                _progress_tracker[task_id]['file_progress'] = {}
            _progress_tracker[task_id]['file_progress'][filename] = 100
            file_progresses = _progress_tracker[task_id]['file_progress']
            if file_progresses:
                total_progress = sum(file_progresses.values()) / len(file_progresses)
                _progress_tracker[task_id]['progress'] = int(total_progress)
            # Record this file's result
            _progress_tracker[task_id]['results'].append(result)
            _progress_tracker[task_id]['message'] = f'Completed processing {filename}'
            # If every file has reported in, finalize the batch status
            total_files = _progress_tracker[task_id].get('total_files', 1)
            completed_count = len([r for r in _progress_tracker[task_id]['results'] if 'error' not in r])
            error_count = len([r for r in _progress_tracker[task_id]['results'] if 'error' in r])
            if completed_count + error_count >= total_files:
                if error_count == 0:
                    _progress_tracker[task_id]['status'] = 'completed'
                    _progress_tracker[task_id]['progress'] = 100
                    _progress_tracker[task_id]['message'] = f'All {total_files} file(s) processed successfully'
                else:
                    # Partial failure still counts as 'completed' so clients stop polling
                    _progress_tracker[task_id]['status'] = 'completed'
                    _progress_tracker[task_id]['message'] = f'Processing complete: {completed_count} succeeded, {error_count} failed'
    except Exception as e:
        logger.error(f"Error processing {filename}: {e}")
        import traceback
        logger.error(traceback.format_exc())
        with _progress_lock:
            _progress_tracker[task_id]['results'].append({
                'filename': filename,
                'error': str(e)
            })
            # If this was the last outstanding file, surface the error state
            total_files = _progress_tracker[task_id].get('total_files', 1)
            if len(_progress_tracker[task_id]['results']) >= total_files:
                _progress_tracker[task_id]['status'] = 'error'
                _progress_tracker[task_id]['message'] = f'Error processing {filename}: {str(e)}'
@app.route('/api/upload', methods=['POST'])
def upload_files():
    """Handle multiple PDF file uploads with sequential background processing.

    Expects multipart form data with files under 'files[]' and an optional
    'extraction_mode' field ('images', 'markdown', or 'both'). Returns a
    task_id that clients poll via /api/progress/<task_id>.

    Fixes: the extension check is now case-insensitive (matching
    /api/predict — '.PDF' uploads were silently dropped before), and queue
    messages now interpolate the filename instead of printing "(unknown)".
    """
    if 'files[]' not in request.files:
        return jsonify({'error': 'No files provided'}), 400
    files = request.files.getlist('files[]')
    extraction_mode = request.form.get('extraction_mode', 'both')
    if not files or all(f.filename == '' for f in files):
        return jsonify({'error': 'No files selected'}), 400
    # Read all file data eagerly before threads start (request context will close)
    file_payloads = []
    for file in files:
        if file and file.filename.lower().endswith('.pdf'):
            data = file.read()
            if data:
                file_payloads.append((file.filename, data))
            else:
                logger.warning(f"Empty file skipped: {file.filename}")
    if not file_payloads:
        return jsonify({'error': 'No valid PDF files could be read'}), 400
    # Create a task ID for this upload batch
    task_id = str(uuid.uuid4())
    with _progress_lock:
        _progress_tracker[task_id] = {
            'status': 'processing',
            'progress': 0,
            'message': f'Queued {len(file_payloads)} file(s) for processing...',
            'results': [],
            'total_files': len(file_payloads),
        }

    def process_queue():
        """Process all files sequentially in a single background thread."""
        total = len(file_payloads)
        for idx, (filename, file_data) in enumerate(file_payloads, start=1):
            with _progress_lock:
                _progress_tracker[task_id]['message'] = f'Processing file {idx} of {total}: {filename}'
            try:
                process_file_background(task_id, file_data, filename, extraction_mode)
            except Exception as e:
                logger.error(f"Unhandled error processing {filename}: {e}")
                import traceback
                logger.error(traceback.format_exc())
                with _progress_lock:
                    _progress_tracker[task_id]['results'].append({
                        'filename': filename,
                        'error': str(e)
                    })
        # Final status update after all files are done
        with _progress_lock:
            tracker = _progress_tracker[task_id]
            good = [r for r in tracker['results'] if 'error' not in r]
            bad = [r for r in tracker['results'] if 'error' in r]
            tracker['status'] = 'completed'
            tracker['progress'] = 100
            if bad:
                tracker['message'] = f'{len(good)} succeeded, {len(bad)} failed.'
            else:
                tracker['message'] = f'All {total} file(s) processed successfully.'

    thread = threading.Thread(target=process_queue)
    thread.daemon = True
    thread.start()
    logger.info(f"Started sequential processing queue for {len(file_payloads)} file(s), task={task_id}")
    return jsonify({
        'task_id': task_id,
        'message': 'Processing started',
        'total_files': len(file_payloads)
    })
@app.route('/api/progress/<task_id>')
def get_progress(task_id):
    """Get progress for a processing task.

    Returns the tracker dict for task_id, enriching each completed result
    in place with absolute URLs for the annotated PDF, markdown file, and
    any extracted figure/table images found on disk.
    """
    # NOTE(review): the lock is held across the filesystem reads below, so a
    # slow disk blocks concurrent progress polls and worker updates — confirm
    # this is acceptable for expected load.
    with _progress_lock:
        progress = _progress_tracker.get(task_id)
        if not progress:
            return jsonify({'error': 'Task not found'}), 404
        # Get base URL for constructing full URLs
        base_url = request.host_url.rstrip('/')
        if 'hf.space' in base_url:
            # Force HTTPS for Hugging Face Spaces (TLS-terminating proxy)
            base_url = base_url.replace('http://', 'https://')
        # Add full URLs to results if they exist (mutates the shared dicts;
        # safe because we hold the lock and URLs are idempotent to rewrite)
        if 'results' in progress:
            for result in progress['results']:
                # Add full URL for annotated PDF
                if result.get('annotated_pdf'):
                    result['annotated_pdf_url'] = f"{base_url}/output/{result['annotated_pdf']}"
                # Add full URL for markdown
                if result.get('markdown_path'):
                    result['markdown_url'] = f"{base_url}/output/{result['markdown_path']}"
                # Add image URLs for figures and tables if available
                output_dir = Path(app.config['OUTPUT_FOLDER']) / result.get('stem', '')
                if output_dir.exists():
                    # Load content list to get figure and table image paths
                    json_files = list(output_dir.glob('*_content_list.json'))
                    if json_files:
                        try:
                            elements = json.loads(json_files[0].read_text(encoding='utf-8'))
                            figures = [e for e in elements if e.get('type') == 'figure']
                            tables = [e for e in elements if e.get('type') == 'table']
                            # Add figure URLs (only for crops that exist on disk)
                            figure_urls = []
                            for fig in figures:
                                if fig.get('image_path'):
                                    img_path = output_dir / fig['image_path']
                                    if img_path.exists():
                                        relative_path = str(img_path.relative_to(app.config['OUTPUT_FOLDER']))
                                        figure_urls.append({
                                            'page': fig.get('page', 0),
                                            'url': f"{base_url}/output/{relative_path}",
                                            'path': relative_path
                                        })
                            # Add table URLs (same filtering as figures)
                            table_urls = []
                            for tab in tables:
                                if tab.get('image_path'):
                                    img_path = output_dir / tab['image_path']
                                    if img_path.exists():
                                        relative_path = str(img_path.relative_to(app.config['OUTPUT_FOLDER']))
                                        table_urls.append({
                                            'page': tab.get('page', 0),
                                            'url': f"{base_url}/output/{relative_path}",
                                            'path': relative_path
                                        })
                            if figure_urls:
                                result['figure_urls'] = figure_urls
                            if table_urls:
                                result['table_urls'] = table_urls
                        except Exception as e:
                            # URL enrichment is best-effort; progress is still returned
                            logger.warning(f"Error loading image URLs for {result.get('stem')}: {e}")
        return jsonify(progress)
@app.route('/api/pdf-list')
def pdf_list():
    """Get list of processed PDFs."""
    root = Path(app.config['OUTPUT_FOLDER'])
    processed = []
    for entry in root.iterdir():
        if not entry.is_dir():
            continue
        # A directory counts as processed if it holds any recognizable output.
        has_content = (
            list(entry.glob('*_content_list.json'))
            or list(entry.glob('*.md'))
            or list(entry.glob('*.pdf'))
        )
        if has_content:
            processed.append({
                'stem': entry.name,
                'output_dir': str(entry.relative_to(app.config['OUTPUT_FOLDER'])),
            })
    return jsonify({'pdfs': processed})
@app.route('/api/pdf-details/<path:pdf_stem>')
def pdf_details(pdf_stem):
    """Get detailed information about a processed PDF.

    Returns figure/table metadata, extracted-image paths, and both absolute
    and relative download URLs for the annotated PDF and markdown output.
    """
    output_root = Path(app.config['OUTPUT_FOLDER'])
    output_dir = output_root / pdf_stem
    # Security fix: the <path:> converter accepts '..' segments, so verify the
    # resolved target stays inside OUTPUT_FOLDER (same guard as _delete_by_stem).
    resolved_root = output_root.resolve()
    resolved_dir = output_dir.resolve()
    if resolved_root not in resolved_dir.parents and resolved_dir != resolved_root:
        return jsonify({'error': 'Invalid stem path'}), 400
    if not output_dir.exists():
        return jsonify({'error': 'PDF not found'}), 404
    # Get base URL for constructing full URLs
    base_url = request.host_url.rstrip('/')
    if 'hf.space' in base_url:
        # Force HTTPS for Hugging Face Spaces
        base_url = base_url.replace('http://', 'https://')
    # Load the content list produced by the extractor
    json_files = list(output_dir.glob('*_content_list.json'))
    elements = []
    if json_files:
        elements = json.loads(json_files[0].read_text(encoding='utf-8'))
    # Partition elements into figures and tables
    figures = [e for e in elements if e.get('type') == 'figure']
    tables = [e for e in elements if e.get('type') == 'table']
    # Locate the annotated layout PDF, if produced
    annotated_pdf = None
    pdf_files = list(output_dir.glob('*_layout.pdf'))
    if pdf_files:
        annotated_pdf = str(pdf_files[0].relative_to(app.config['OUTPUT_FOLDER']))
    # Locate the markdown output, if produced
    markdown_path = None
    md_files = list(output_dir.glob('*.md'))
    if md_files:
        markdown_path = str(md_files[0].relative_to(app.config['OUTPUT_FOLDER']))
    # Collect extracted figure and table image crops
    figure_dir = output_dir / 'figures'
    table_dir = output_dir / 'tables'
    figure_images = []
    if figure_dir.exists():
        figure_images = [str(f.relative_to(app.config['OUTPUT_FOLDER']))
                         for f in sorted(figure_dir.glob('*.png'))]
    table_images = []
    if table_dir.exists():
        table_images = [str(t.relative_to(app.config['OUTPUT_FOLDER']))
                        for t in sorted(table_dir.glob('*.png'))]
    return jsonify({
        'stem': pdf_stem,
        'figures': figures,
        'tables': tables,
        'figures_count': len(figures),
        'tables_count': len(tables),
        'elements_count': len(elements),
        'annotated_pdf': annotated_pdf,
        'markdown_path': markdown_path,
        'figure_images': figure_images,
        'table_images': table_images,
        # Add full URLs for direct access
        'urls': {
            'annotated_pdf': f"{base_url}/output/{annotated_pdf}" if annotated_pdf else None,
            'markdown': f"{base_url}/output/{markdown_path}" if markdown_path else None,
            'figures': [f"{base_url}/output/{img}" for img in figure_images] if figure_images else [],
            'tables': [f"{base_url}/output/{img}" for img in table_images] if table_images else [],
        },
        # Keep relative paths for backward compatibility
        'download_urls': {
            'annotated_pdf': f"/output/{annotated_pdf}" if annotated_pdf else None,
            'markdown': f"/output/{markdown_path}" if markdown_path else None,
            'figures': [f"/output/{img}" for img in figure_images] if figure_images else [],
            'tables': [f"/output/{img}" for img in table_images] if table_images else [],
        }
    })
@app.route('/output/<path:filename>')
def output_file(filename):
    """Serve output files (PDFs, images, markdown).

    Resolves the requested path under OUTPUT_FOLDER, rejects anything that
    escapes it, and serves the file inline with a best-guess MIME type.
    """
    try:
        output_folder = Path(app.config['OUTPUT_FOLDER']).resolve()
        file_path = (output_folder / filename).resolve()
        # Security fix: the old str(...).startswith(...) check also accepted
        # sibling directories like 'output_evil'; compare path components.
        if file_path != output_folder and output_folder not in file_path.parents:
            return jsonify({'error': 'Invalid file path'}), 400
        # Check if file exists
        if not file_path.exists():
            return jsonify({
                'error': 'File not found',
                'requested_path': filename,
                'hint': 'Use /api/pdf-details/<stem> to get correct file paths'
            }), 404
        if not file_path.is_file():
            return jsonify({'error': 'Path is not a file'}), 400
        # Determine MIME type based on extension
        mime_types = {
            '.pdf': 'application/pdf',
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.md': 'text/markdown',
            '.json': 'application/json',
            '.txt': 'text/plain'
        }
        ext = file_path.suffix.lower()
        mimetype = mime_types.get(ext, 'application/octet-stream')
        return send_file(str(file_path), mimetype=mimetype, as_attachment=False)
    except Exception as e:
        # Fix: log message previously printed literal "(unknown)" instead of
        # the requested filename.
        logger.error(f"Error serving file {filename}: {e}")
        return jsonify({
            'error': 'Failed to serve file',
            'message': str(e)
        }), 500
def _delete_by_stem(stem_raw: str):
    """Delete the output directory for a stem, guarding against traversal.

    Returns a (json, status) Flask response tuple: 400 for a missing or
    escaping stem, 404 when no such directory exists, 200 on success.
    """
    stem = (stem_raw or "").strip()
    if not stem:
        return jsonify({'error': 'Missing stem'}), 400
    # Resolve both paths and require the target to live under the output root.
    output_root = Path(app.config['OUTPUT_FOLDER']).resolve()
    target_dir = (output_root / stem).resolve()
    inside_root = output_root in target_dir.parents or target_dir == output_root
    if not inside_root:
        return jsonify({'error': 'Invalid stem path'}), 400
    if not (target_dir.exists() and target_dir.is_dir()):
        return jsonify({'error': 'Not found'}), 404
    # Remove the whole directory tree; errors propagate to the caller.
    shutil.rmtree(target_dir, ignore_errors=False)
    logger.info(f"Deleted processed output: {target_dir}")
    return jsonify({'ok': True, 'deleted': stem})
@app.route('/api/delete', methods=['POST'])
def delete_pdf():
    """Delete a processed PDF directory by stem (JSON or form body)."""
    try:
        # Accept the stem from a JSON body first, then fall back to form data.
        payload = request.get_json(silent=True) or {}
        stem_value = (payload.get('stem') or request.form.get('stem') or '').strip()
        return _delete_by_stem(stem_value)
    except Exception as e:
        logger.error(f"Delete failed: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/delete/<path:stem>', methods=['POST', 'GET'])
def delete_pdf_by_path(stem: str):
    """Alternate endpoint to delete using URL path, for clients avoiding bodies."""
    try:
        return _delete_by_stem(stem)
    except Exception as exc:
        logger.error(f"Delete failed: {exc}")
        return jsonify({'error': str(exc)}), 500
@app.route('/api/download-zip/<path:stem>', methods=['GET'])
def download_zip(stem: str):
    """Download the processed output as a zip archive.

    Builds the archive entirely in memory and streams it as an attachment.
    """
    import io
    import zipfile
    stem = stem.strip()
    if not stem:
        return jsonify({'error': 'Missing stem'}), 400
    output_root = Path(app.config['OUTPUT_FOLDER']).resolve()
    target_dir = (output_root / stem).resolve()
    # Reject anything that resolves outside the output root (path traversal).
    if not (target_dir == output_root or output_root in target_dir.parents):
        return jsonify({'error': 'Invalid stem path'}), 400
    if not (target_dir.exists() and target_dir.is_dir()):
        return jsonify({'error': 'PDF not found or not processed completely'}), 404
    try:
        buffer = io.BytesIO()
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
            # Archive every file under the directory, keyed by relative path.
            for entry in sorted(target_dir.rglob('*')):
                if entry.is_file():
                    archive.write(entry, entry.relative_to(target_dir))
        buffer.seek(0)
        return send_file(
            buffer,
            mimetype='application/zip',
            as_attachment=True,
            download_name=f"{stem}_extracted.zip"
        )
    except Exception as e:
        logger.error(f"Zip creation failed: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/api/download-all/<task_id>', methods=['GET'])
def download_all(task_id: str):
    """Download all output directories for a task as a single ZIP archive.

    Looks up the task's result stems under the lock, then zips each stem's
    directory (namespaced as stem/filename inside the archive) in memory.
    """
    import io, zipfile
    # Snapshot the stems while holding the lock; file I/O happens afterwards.
    with _progress_lock:
        tracker = _progress_tracker.get(task_id)
        if not tracker:
            return jsonify({'error': 'Task not found'}), 404
        stems = [r.get('stem') for r in tracker.get('results', []) if r.get('stem')]
    if not stems:
        return jsonify({'error': 'No processed files found for this task'}), 404
    output_root = Path(app.config['OUTPUT_FOLDER']).resolve()
    memory_file = io.BytesIO()
    try:
        with zipfile.ZipFile(memory_file, 'w', zipfile.ZIP_DEFLATED) as zf:
            for stem in stems:
                target_dir = (output_root / stem).resolve()
                # Prevent path traversal: skip stems resolving outside the root
                if output_root not in target_dir.parents and target_dir != output_root:
                    continue
                if not target_dir.exists():
                    continue
                for root, _, files in os.walk(target_dir):
                    for file in files:
                        file_path = Path(root) / file
                        # Archive under stem/filename so files don't collide
                        arcname = Path(stem) / file_path.relative_to(target_dir)
                        zf.write(file_path, arcname)
        memory_file.seek(0)
        return send_file(
            memory_file,
            mimetype='application/zip',
            as_attachment=True,
            download_name='all_extracted.zip'
        )
    except Exception as e:
        logger.error(f"Download-all zip creation failed: {e}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Hugging Face Spaces expects port 7860; PORT env var overrides locally.
    listen_port = int(os.environ.get('PORT', 7860))
    app.run(debug=False, host='0.0.0.0', port=listen_port)