saifisvibinn committed on
Commit
cd71fbc
·
1 Parent(s): 27a4694

Add /api/predict endpoint implementation

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py CHANGED
@@ -89,6 +89,196 @@ def api_docs():
89
  return render_template('api_docs.html', routes=routes, base_url=base_url)
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  @app.route('/api/device-info')
93
  def device_info():
94
  """API endpoint to get device information."""
 
89
  return render_template('api_docs.html', routes=routes, base_url=base_url)
90
 
91
 
92
def _count_pdf_pages(pdf_path):
    """Return the number of pages in the PDF at *pdf_path*, or 0 on failure.

    The page count is only informational, so any failure (pypdfium2 missing,
    unreadable PDF, ...) is logged and reported as 0 instead of failing the
    whole request.
    """
    try:
        import pypdfium2 as pdfium

        doc = pdfium.PdfDocument(pdf_path.read_bytes())
        try:
            return len(doc)
        finally:
            doc.close()
    except Exception as e:
        logger.warning("Could not determine page count for %s: %s", pdf_path.name, e)
        return 0


def _image_to_data_uri(img_path):
    """Re-encode the image at *img_path* as PNG and return it as a base64 data URI."""
    import io

    from PIL import Image

    buffer = io.BytesIO()
    # Context manager guarantees the underlying file handle is released.
    with Image.open(img_path) as img:
        img.save(buffer, format='PNG')
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"


def _format_element(element, output_dir):
    """Build the API representation of one extracted figure/table element.

    *element* is one entry of the extractor's content list (assumed to be a
    dict -- TODO confirm against the extractor's output schema). When the
    element references an image that exists under *output_dir*, the image is
    embedded as a base64 PNG data URI alongside its relative path.
    """
    data = {
        'page': element.get('page', 0),
        'bbox': element.get('bbox_pixels', []),
        'confidence': element.get('conf', 0.0),
        'width': element.get('width', 0),
        'height': element.get('height', 0),
    }

    relative_image = element.get('image_path')
    if relative_image:
        img_path = output_dir / relative_image
        if img_path.is_file():
            data['image_base64'] = _image_to_data_uri(img_path)
            data['image_path'] = relative_image

    return data


@app.route('/api/predict', methods=['POST'])
def predict():
    """
    Clean REST API endpoint for PDF extraction.
    Accepts a PDF file and returns extracted text, tables, and figures.

    Request:
    - Method: POST
    - Content-Type: multipart/form-data
    - Body: file (PDF file)

    Response (200):
    {
        "status": "success",
        "filename": "document.pdf",
        "text": "extracted markdown text...",
        "tables": [...],
        "figures": [...],
        "summary": {...}
    }

    Errors:
    - 400 with {"status": "error", "error": "..."} when the upload is
      missing, unnamed, or not a ``.pdf`` file.
    - 500 with {"status": "error", "error": "..."} when processing fails.
    """
    try:
        # Check if file is present
        if 'file' not in request.files:
            return jsonify({
                'status': 'error',
                'error': 'No file provided. Please upload a PDF file using the "file" field.'
            }), 400

        file = request.files['file']

        # Covers both an empty string and a missing (None) filename.
        if not file.filename:
            return jsonify({
                'status': 'error',
                'error': 'No file selected'
            }), 400

        if not file.filename.lower().endswith('.pdf'):
            return jsonify({
                'status': 'error',
                'error': 'Invalid file type. Please upload a PDF file.'
            }), 400

        filename = secure_filename(file.filename)
        # secure_filename() drops non-ASCII characters, so a name such as
        # "файл.pdf" collapses to "pdf" (or to an empty string). Fall back to
        # a generic name so the saved file keeps a usable stem and extension.
        if not filename.lower().endswith('.pdf'):
            filename = 'document.pdf'
        stem = Path(filename).stem

        # Per-request scratch directories; the random suffix keeps concurrent
        # requests from colliding.
        temp_upload = Path(app.config['UPLOAD_FOLDER']) / f"temp_{uuid.uuid4().hex}"
        temp_output = Path(app.config['OUTPUT_FOLDER']) / f"temp_{uuid.uuid4().hex}"

        try:
            # Create the upload directory itself (not just its parent),
            # otherwise writing the PDF below fails with FileNotFoundError.
            temp_upload.mkdir(parents=True, exist_ok=True)
            temp_output.mkdir(parents=True, exist_ok=True)

            # Save uploaded file
            pdf_path = temp_upload / filename
            pdf_path.write_bytes(file.read())

            # Load model if needed
            load_model_once()

            # Process PDF (extract both images and markdown).
            # NOTE: this flips a module-level flag on `extractor`, so it also
            # affects any other code path using the extractor afterwards.
            extractor.USE_MULTIPROCESSING = False
            extractor.process_pdf_with_pool(
                pdf_path,
                temp_output,
                pool=None,
                extract_images=True,
                extract_markdown=True,
            )

            # Collect extracted data
            result = {
                'status': 'success',
                'filename': filename,
                'text': '',
                'tables': [],
                'figures': [],
                'summary': {
                    # Page count comes from the PDF itself, so report it even
                    # when the extractor produced no content list.
                    'total_pages': _count_pdf_pages(pdf_path),
                    'figures_count': 0,
                    'tables_count': 0,
                    'elements_count': 0
                }
            }

            # Extract markdown text
            markdown_path = temp_output / f"{stem}.md"
            if markdown_path.exists():
                result['text'] = markdown_path.read_text(encoding='utf-8')

            # Extract figures and tables from JSON
            json_path = temp_output / f"{stem}_content_list.json"
            if json_path.exists():
                elements = json.loads(json_path.read_text(encoding='utf-8'))

                figures = [e for e in elements if e.get('type') == 'figure']
                tables = [e for e in elements if e.get('type') == 'table']

                result['figures'] = [_format_element(fig, temp_output) for fig in figures]
                result['tables'] = [_format_element(tab, temp_output) for tab in tables]

                result['summary']['figures_count'] = len(figures)
                result['summary']['tables_count'] = len(tables)
                result['summary']['elements_count'] = len(elements)

            return jsonify(result)

        finally:
            # Best-effort cleanup; rmtree(ignore_errors=True) tolerates
            # directories that were never created or are already gone.
            for temp_dir in (temp_upload, temp_output):
                shutil.rmtree(temp_dir, ignore_errors=True)

    except Exception as e:
        # logger.exception records the message plus the full traceback.
        logger.exception("Error in /api/predict: %s", e)
        return jsonify({
            'status': 'error',
            'error': str(e)
        }), 500
280
+
281
+
282
  @app.route('/api/device-info')
283
  def device_info():
284
  """API endpoint to get device information."""