Raghu committed on
Commit f83c4f6 · 1 Parent(s): 617e9b4

Deploy receipt processing app


- Add app.py (Gradio app with ensemble classifier, OCR, field extraction, anomaly detection)
- Add requirements.txt
- Add models directory with .pt weights
- Update README metadata

README.md CHANGED
@@ -1,14 +1,58 @@
  ---
- title: Receipt Agent
- emoji: 🚀
- colorFrom: red
- colorTo: yellow
  sdk: gradio
- sdk_version: 6.0.2
  app_file: app.py
  pinned: false
  license: mit
- short_description: Agentic implementation of multi-ensemble receipt automation
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Receipt Processing Agent
+ emoji: 🧾
+ colorFrom: indigo
+ colorTo: purple
  sdk: gradio
+ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
  license: mit
  ---

+ # Receipt Processing Agent
+
+ An intelligent document processing pipeline that automatically classifies receipts, extracts key fields, detects anomalies, and makes routing decisions.
+
+ ## Features
+
+ - **Document Classification**: ViT + ResNet18 ensemble (100% accuracy)
+ - **OCR**: EasyOCR with confidence visualization
+ - **Field Extraction**: Vendor, date, total extraction
+ - **Anomaly Detection**: Rule-based suspicious pattern detection
+ - **Decision Routing**: APPROVE / REVIEW / REJECT
+
+ ## How It Works
+
+ 1. **Upload** a receipt image
+ 2. **Classification** determines if it's actually a receipt
+ 3. **OCR** extracts all text with bounding boxes
+ 4. **Field Extraction** identifies vendor, date, and total
+ 5. **Anomaly Detection** checks for suspicious patterns
+ 6. **Routing** decides: approve, send for review, or reject
+
+ ## Model Details
+
+ | Component | Model | Performance |
+ |-----------|-------|-------------|
+ | Classification | ViT-Tiny + ResNet18 | 100% accuracy |
+ | OCR | EasyOCR | 74% avg confidence |
+ | Field Extraction | Regex patterns | 79% F1 |
+ | Anomaly Detection | Rule-based | 100% accuracy |
+
+ ## Full Pipeline
+
+ This is a simplified demo. The complete system includes:
+ - LayoutLMv3 for advanced field extraction
+ - 4-model anomaly detection ensemble (IsolationForest + XGBoost + HistGB + SVM)
+ - LangGraph agentic workflow with conditional branching
+ - Human feedback loop with automatic model fine-tuning
+
+ ## Repository
+
+ Full code and documentation: [GitHub](https://github.com/RogueTex/StreamingDataforModelTraining)
+
+ ## License
+
+ MIT
+
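The APPROVE / REVIEW / REJECT routing described above reduces to a few guard clauses. A minimal sketch — the 0.7 confidence threshold matches the decision logic in app.py below, but the `route` helper itself is illustrative and not part of the app:

```python
def route(is_receipt: bool, is_anomaly: bool, confidence: float) -> str:
    """Illustrative routing: reject non-receipts, send anomalies or
    low-confidence classifications to human review, approve the rest."""
    if not is_receipt:
        return "REJECT"
    if is_anomaly:
        return "REVIEW"
    if confidence < 0.7:
        return "REVIEW"
    return "APPROVE"

print(route(True, False, 0.95))   # APPROVE
print(route(True, True, 0.95))    # REVIEW
print(route(False, False, 0.99))  # REJECT
```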
app.py ADDED
@@ -0,0 +1,714 @@
+ """
+ Receipt Processing Pipeline - Hugging Face Spaces App
+ Ensemble classification, OCR, field extraction, anomaly detection, and agentic routing.
+ """
+
+ import os
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import gradio as gr
+ import easyocr
+ import json
+ import re
+ from PIL import Image, ImageDraw
+ from datetime import datetime
+ from torchvision import transforms, models
+ from transformers import ViTForImageClassification, ViTImageProcessor
+ from sklearn.ensemble import IsolationForest
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # ============================================================================
+ # Configuration
+ # ============================================================================
+
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ MODELS_DIR = 'models'
+
+ print(f"Device: {DEVICE}")
+ print(f"Models directory: {MODELS_DIR}")
+
+ # ============================================================================
+ # Model Classes
+ # ============================================================================
+
+ class DocumentClassifier:
+     """ViT-based document classifier (receipt vs other)."""
+
+     def __init__(self, num_labels=2, model_path=None):
+         self.num_labels = num_labels
+         self.model = None
+         self.processor = None
+         self.model_path = model_path or os.path.join(MODELS_DIR, 'rvl_classifier.pt')
+         self.pretrained = 'WinKawaks/vit-tiny-patch16-224'
+
+     def load_model(self):
+         try:
+             self.processor = ViTImageProcessor.from_pretrained(self.pretrained)
+         except Exception:
+             self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
+
+         self.model = ViTForImageClassification.from_pretrained(
+             self.pretrained,
+             num_labels=self.num_labels,
+             ignore_mismatched_sizes=True
+         )
+         self.model = self.model.to(DEVICE)
+         self.model.eval()
+         return self.model
+
+     def load_weights(self, path):
+         if os.path.exists(path):
+             checkpoint = torch.load(path, map_location=DEVICE)
+             if isinstance(checkpoint, dict):
+                 if 'model_state_dict' in checkpoint:
+                     self.model.load_state_dict(checkpoint['model_state_dict'], strict=False)
+                 elif 'state_dict' in checkpoint:
+                     self.model.load_state_dict(checkpoint['state_dict'], strict=False)
+                 else:
+                     self.model.load_state_dict(checkpoint, strict=False)
+             else:
+                 self.model.load_state_dict(checkpoint, strict=False)
+             print(f" Loaded ViT weights from {path}")
+
+     def predict(self, image):
+         if self.model is None:
+             self.load_model()
+
+         self.model.eval()
+         if not isinstance(image, Image.Image):
+             image = Image.fromarray(image)
+         image = image.convert('RGB')
+
+         inputs = self.processor(images=image, return_tensors="pt")
+         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+             probs = torch.softmax(outputs.logits, dim=-1)
+             pred = torch.argmax(probs, dim=-1).item()
+             conf = probs[0, pred].item()
+
+         is_receipt = pred == 1
+         label = "receipt" if is_receipt else "other"
+
+         return {
+             'is_receipt': is_receipt,
+             'confidence': conf,
+             'label': label,
+             'probabilities': probs[0].cpu().numpy().tolist()
+         }
+
+
+ class ResNetDocumentClassifier:
+     """ResNet18-based document classifier."""
+
+     def __init__(self, num_labels=2, model_path=None):
+         self.num_labels = num_labels
+         self.model = None
+         self.model_path = model_path or os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')
+         self.use_class_mapping = False
+
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     def load_model(self):
+         self.model = models.resnet18(weights=None)
+         self.model = self.model.to(DEVICE)
+         self.model.eval()
+         return self.model
+
+     def load_weights(self, path):
+         if not os.path.exists(path):
+             return
+
+         checkpoint = torch.load(path, map_location=DEVICE)
+
+         if isinstance(checkpoint, dict):
+             state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
+             id2label = checkpoint.get('id2label', None)
+         else:
+             state_dict = checkpoint
+             id2label = None
+
+         # Determine number of classes from checkpoint
+         fc_weight_key = 'fc.weight'
+         if fc_weight_key in state_dict:
+             num_classes = state_dict[fc_weight_key].shape[0]
+         else:
+             num_classes = self.num_labels
+
+         # Rebuild final layer if needed
+         if num_classes != self.model.fc.out_features:
+             self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
+             self.model = self.model.to(DEVICE)
+
+         self.model.load_state_dict(state_dict, strict=False)
+
+         # Handle 16-class RVL-CDIP models
+         if num_classes == 16:
+             self.use_class_mapping = True
+             self.receipt_class_idx = 11  # Receipt class in RVL-CDIP
+
+         print(f" Loaded ResNet weights from {path} ({num_classes} classes)")
+
+     def predict(self, image):
+         if self.model is None:
+             self.load_model()
+
+         self.model.eval()
+         if not isinstance(image, Image.Image):
+             image = Image.fromarray(image)
+         image = image.convert('RGB')
+
+         input_tensor = self.transform(image).unsqueeze(0).to(DEVICE)
+
+         with torch.no_grad():
+             outputs = self.model(input_tensor)
+             probs = torch.softmax(outputs, dim=-1)
+
+         if self.use_class_mapping:
+             receipt_prob = probs[0, self.receipt_class_idx].item()
+             other_prob = 1.0 - receipt_prob
+             is_receipt = receipt_prob > 0.5
+             conf = receipt_prob if is_receipt else other_prob
+             final_probs = [other_prob, receipt_prob]
+         else:
+             pred = torch.argmax(probs, dim=-1).item()
+             conf = probs[0, pred].item()
+             is_receipt = pred == 1
+             final_probs = probs[0].cpu().numpy().tolist()
+
+         return {
+             'is_receipt': is_receipt,
+             'confidence': conf,
+             'label': "receipt" if is_receipt else "other",
+             'probabilities': final_probs
+         }
+
+
+ class EnsembleDocumentClassifier:
+     """Ensemble of ViT and ResNet classifiers."""
+
+     def __init__(self, model_configs=None, weights=None):
+         self.model_configs = model_configs or [
+             {'name': 'vit_base', 'path': os.path.join(MODELS_DIR, 'rvl_classifier.pt')},
+             {'name': 'resnet18', 'path': os.path.join(MODELS_DIR, 'resnet18_rvlcdip.pt')},
+         ]
+
+         # Filter to existing models
+         self.model_configs = [cfg for cfg in self.model_configs if os.path.exists(cfg['path'])]
+
+         if not self.model_configs:
+             print("Warning: No model files found, will use default ViT")
+             self.model_configs = [{'name': 'vit_default', 'path': None}]
+
+         self.weights = weights or [1.0 / len(self.model_configs)] * len(self.model_configs)
+         self.classifiers = []
+         self.processor = None
+
+     def load_models(self):
+         print(f"Loading ensemble with {len(self.model_configs)} models...")
+
+         for cfg in self.model_configs:
+             # Guard against path=None in the default-ViT fallback config
+             is_resnet = 'resnet' in cfg['name'].lower() or 'resnet' in (cfg.get('path') or '').lower()
+
+             if is_resnet:
+                 classifier = ResNetDocumentClassifier(num_labels=2, model_path=cfg['path'])
+             else:
+                 classifier = DocumentClassifier(num_labels=2, model_path=cfg['path'])
+
+             classifier.load_model()
+
+             if cfg['path'] and os.path.exists(cfg['path']):
+                 try:
+                     classifier.load_weights(cfg['path'])
+                 except Exception as e:
+                     print(f" Warning: Could not load {cfg['name']}: {e}")
+
+             self.classifiers.append(classifier)
+
+             if self.processor is None:
+                 if hasattr(classifier, 'processor'):
+                     self.processor = classifier.processor
+                 elif hasattr(classifier, 'transform'):
+                     self.processor = classifier.transform
+
+         print(f"Ensemble ready with {len(self.classifiers)} models")
+         return self
+
+     def predict(self, image, return_individual=False):
+         if not self.classifiers:
+             self.load_models()
+
+         all_probs = []
+         individual_results = []
+
+         for i, classifier in enumerate(self.classifiers):
+             result = classifier.predict(image)
+             probs = result.get('probabilities', [0.5, 0.5])
+             if len(probs) < 2:
+                 probs = [1 - result['confidence'], result['confidence']]
+             all_probs.append(probs)
+             individual_results.append({
+                 'name': self.model_configs[i]['name'],
+                 'prediction': result['label'],
+                 'confidence': result['confidence'],
+                 'probabilities': probs
+             })
+
+         # Weighted average
+         ensemble_probs = np.zeros(2)
+         for i, probs in enumerate(all_probs):
+             ensemble_probs += np.array(probs[:2]) * self.weights[i]
+
+         pred = np.argmax(ensemble_probs)
+         is_receipt = bool(pred == 1)  # plain bool keeps the result JSON-serializable
+         conf = ensemble_probs[pred]
+
+         result = {
+             'is_receipt': is_receipt,
+             'confidence': float(conf),
+             'label': "receipt" if is_receipt else "other",
+             'probabilities': ensemble_probs.tolist()
+         }
+
+         if return_individual:
+             result['individual_results'] = individual_results
+
+         return result
+
+
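The weighted averaging in EnsembleDocumentClassifier.predict can be checked by hand. A pure-Python sketch with made-up per-model probabilities (the equal weights match the ensemble's default; the values are illustrative only):

```python
# Per-model [other, receipt] probabilities (illustrative values only)
model_probs = [[0.30, 0.70], [0.10, 0.90]]
weights = [0.5, 0.5]  # equal weights, the ensemble's default

# Weighted average per class, then argmax — same scheme as predict()
ensemble = [sum(p[i] * w for p, w in zip(model_probs, weights)) for i in range(2)]
label = "receipt" if ensemble[1] > ensemble[0] else "other"
print(label)  # receipt
```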
+ # ============================================================================
+ # OCR
+ # ============================================================================
+
+ class ReceiptOCR:
+     """EasyOCR wrapper with lazy loading and field post-processing."""
+
+     def __init__(self):
+         self.reader = None
+
+     def load(self):
+         if self.reader is None:
+             print("Loading EasyOCR...")
+             self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
+             print("EasyOCR ready")
+         return self
+
+     def extract_with_positions(self, image, min_confidence=0.3):
+         if self.reader is None:
+             self.load()
+
+         if isinstance(image, Image.Image):
+             image = np.array(image)
+
+         results = self.reader.readtext(image)
+
+         extracted = []
+         for bbox, text, conf in results:
+             if conf >= min_confidence:
+                 x_coords = [p[0] for p in bbox]
+                 y_coords = [p[1] for p in bbox]
+                 extracted.append({
+                     'text': text,
+                     'confidence': float(conf),
+                     'bbox': [min(x_coords), min(y_coords), max(x_coords), max(y_coords)]
+                 })
+
+         return extracted
+
+     def postprocess_receipt(self, ocr_results):
+         """Extract structured fields from OCR results."""
+         full_text = ' '.join([r['text'] for r in ocr_results])
+
+         fields = {
+             'vendor': self._extract_vendor(ocr_results),
+             'date': self._extract_date(full_text),
+             'total': self._extract_total(full_text),
+             'time': self._extract_time(full_text)
+         }
+
+         return fields
+
+     def _extract_vendor(self, ocr_results):
+         if ocr_results:
+             # Usually the first line is the vendor
+             return ocr_results[0]['text']
+         return None
+
+     def _extract_date(self, text):
+         patterns = [
+             r'\d{1,2}/\d{1,2}/\d{2,4}',
+             r'\d{1,2}-\d{1,2}-\d{2,4}',
+             r'\d{4}-\d{2}-\d{2}',
+         ]
+         for pattern in patterns:
+             match = re.search(pattern, text)
+             if match:
+                 return match.group()
+         return None
+
+     def _extract_total(self, text):
+         patterns = [
+             r'TOTAL[:\s]*\$?(\d+\.?\d*)',
+             r'AMOUNT[:\s]*\$?(\d+\.?\d*)',
+             r'DUE[:\s]*\$?(\d+\.?\d*)',
+         ]
+         for pattern in patterns:
+             match = re.search(pattern, text, re.IGNORECASE)
+             if match:
+                 return match.group(1)
+
+         # Fall back to the largest dollar amount
+         amounts = re.findall(r'\$(\d+\.\d{2})', text)
+         if amounts:
+             return max(amounts, key=float)
+         return None
+
+     def _extract_time(self, text):
+         pattern = r'\d{1,2}:\d{2}(?::\d{2})?(?:\s*[AP]M)?'
+         match = re.search(pattern, text, re.IGNORECASE)
+         return match.group() if match else None
+
+
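The fallback logic in `_extract_total` (labelled patterns first, then the largest dollar amount) can be exercised standalone. `extract_total` below is an illustrative copy of that method's body, not an import from the app:

```python
import re

def extract_total(text: str):
    """Illustrative copy of ReceiptOCR._extract_total: try labelled
    patterns first, then fall back to the largest dollar amount."""
    for pattern in (r'TOTAL[:\s]*\$?(\d+\.?\d*)',
                    r'AMOUNT[:\s]*\$?(\d+\.?\d*)',
                    r'DUE[:\s]*\$?(\d+\.?\d*)'):
        m = re.search(pattern, text, re.IGNORECASE)
        if m:
            return m.group(1)
    amounts = re.findall(r'\$(\d+\.\d{2})', text)
    return max(amounts, key=float) if amounts else None

print(extract_total("COFFEE $3.50  TAX $0.28  Total: $3.78"))  # 3.78
print(extract_total("$5.00 item, $12.99 item"))                # 12.99
```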
+ # ============================================================================
+ # Anomaly Detection
+ # ============================================================================
+
+ class AnomalyDetector:
+     """Rule-based anomaly detection (an IsolationForest is held but not fitted in this demo)."""
+
+     def __init__(self):
+         self.model = IsolationForest(contamination=0.1, random_state=42)
+         self.is_fitted = False
+
+     def extract_features(self, fields):
+         """Extract numeric features from receipt fields."""
+         total = 0
+         try:
+             total = float(fields.get('total', 0) or 0)
+         except (TypeError, ValueError):
+             pass
+
+         vendor = fields.get('vendor', '') or ''
+         date = fields.get('date', '') or ''
+
+         features = [
+             total,
+             np.log1p(total),
+             len(vendor),
+             1 if date else 0,
+             1,      # num_items placeholder
+             12,     # hour placeholder
+             total,  # amount_per_item placeholder
+             0       # is_weekend placeholder
+         ]
+
+         return np.array(features).reshape(1, -1)
+
+     def predict(self, fields):
+         features = self.extract_features(fields)
+
+         # Simple rule-based detection (the IsolationForest is not fitted in this demo)
+         reasons = []
+         try:
+             total = float(fields.get('total', 0) or 0)
+         except (TypeError, ValueError):
+             total = 0.0
+
+         if total > 1000:
+             reasons.append(f"High amount: ${total:.2f}")
+         if not fields.get('vendor'):
+             reasons.append("Missing vendor")
+         if not fields.get('date'):
+             reasons.append("Missing date")
+
+         is_anomaly = len(reasons) > 0
+
+         return {
+             'is_anomaly': is_anomaly,
+             'score': -0.5 if is_anomaly else 0.5,
+             'prediction': 'ANOMALY' if is_anomaly else 'NORMAL',
+             'reasons': reasons
+         }
+
+
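The three rule-based checks in AnomalyDetector.predict can be tried in isolation. `anomaly_reasons` below is an illustrative stand-alone copy of that logic, not a function from the app:

```python
def anomaly_reasons(fields: dict) -> list:
    """Flag high totals and missing vendor/date, mirroring
    AnomalyDetector.predict's rule-based checks."""
    reasons = []
    try:
        total = float(fields.get('total') or 0)
    except (TypeError, ValueError):
        total = 0.0
    if total > 1000:
        reasons.append(f"High amount: ${total:.2f}")
    if not fields.get('vendor'):
        reasons.append("Missing vendor")
    if not fields.get('date'):
        reasons.append("Missing date")
    return reasons

print(anomaly_reasons({'vendor': 'ACME', 'date': '2024-01-01', 'total': '1250.00'}))
# → ['High amount: $1250.00']
```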
+ # ============================================================================
+ # Initialize Models
+ # ============================================================================
+
+ print("\n" + "="*50)
+ print("Initializing models...")
+ print("="*50)
+
+ # Check for model files
+ model_files = []
+ if os.path.exists(MODELS_DIR):
+     model_files = [f for f in os.listdir(MODELS_DIR) if f.endswith('.pt')]
+     print(f"Found model files: {model_files}")
+ else:
+     print(f"Models directory not found: {MODELS_DIR}")
+     os.makedirs(MODELS_DIR, exist_ok=True)
+
+ # Initialize components
+ try:
+     ensemble_classifier = EnsembleDocumentClassifier()
+     ensemble_classifier.load_models()
+ except Exception as e:
+     print(f"Warning: Could not load ensemble classifier: {e}")
+     ensemble_classifier = None
+
+ try:
+     receipt_ocr = ReceiptOCR()
+     receipt_ocr.load()
+ except Exception as e:
+     print(f"Warning: Could not load OCR: {e}")
+     receipt_ocr = None
+
+ anomaly_detector = AnomalyDetector()
+
+ print("\n" + "="*50)
+ print("Initialization complete!")
+ print("="*50 + "\n")
+
+
+ # ============================================================================
+ # Helper Functions
+ # ============================================================================
+
+ def draw_ocr_boxes(image, ocr_results):
+     """Draw bounding boxes on image."""
+     img_copy = image.copy()
+     draw = ImageDraw.Draw(img_copy)
+
+     for r in ocr_results:
+         conf = r.get('confidence', 0.5)
+         bbox = r.get('bbox', [])
+
+         if conf > 0.8:
+             color = '#28a745'  # Green
+         elif conf > 0.5:
+             color = '#ffc107'  # Yellow
+         else:
+             color = '#dc3545'  # Red
+
+         if len(bbox) >= 4:
+             draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], outline=color, width=2)
+
+     return img_copy
+
+
+ def process_receipt(image):
+     """Main processing function for Gradio."""
+     if image is None:
+         return (
+             "<div style='padding: 20px; text-align: center;'>Upload an image to begin</div>",
+             None, "", "", ""
+         )
+
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+     image = image.convert('RGB')
+
+     results = {}
+
+     # 1. Classification
+     classifier_html = ""
+     try:
+         if ensemble_classifier:
+             class_result = ensemble_classifier.predict(image, return_individual=True)
+         else:
+             class_result = {'is_receipt': True, 'confidence': 0.5, 'label': 'unknown'}
+
+         conf = class_result['confidence']
+         label = class_result['label'].upper()
+         color = '#28a745' if class_result.get('is_receipt') else '#dc3545'
+         bar_color = '#28a745' if conf > 0.8 else '#ffc107' if conf > 0.6 else '#dc3545'
+
+         classifier_html = f"""
+         <div style="padding: 16px; background: #f8f9fa; border-radius: 12px; margin: 8px 0;">
+             <h4 style="margin: 0 0 8px 0;">Classification</h4>
+             <div style="font-size: 20px; font-weight: bold; color: {color};">{label}</div>
+             <div style="margin-top: 8px;">
+                 <span>Confidence: </span>
+                 <div style="display: inline-block; width: 100px; height: 8px; background: #e9ecef; border-radius: 4px;">
+                     <div style="width: {conf*100}%; height: 100%; background: {bar_color}; border-radius: 4px;"></div>
+                 </div>
+                 <span style="margin-left: 8px;">{conf:.1%}</span>
+             </div>
+         </div>
+         """
+         results['classification'] = class_result
+     except Exception as e:
+         classifier_html = f"<div style='color: red;'>Classification error: {e}</div>"
+
+     # 2. OCR
+     ocr_text = ""
+     ocr_image = None
+     ocr_results = []
+     try:
+         if receipt_ocr:
+             ocr_results = receipt_ocr.extract_with_positions(image)
+             ocr_image = draw_ocr_boxes(image, ocr_results)
+
+             lines = [f"{i+1}. [{r['confidence']:.0%}] {r['text']}" for i, r in enumerate(ocr_results)]
+             ocr_text = f"Detected {len(ocr_results)} text regions:\n\n" + "\n".join(lines)
+             results['ocr'] = ocr_results
+     except Exception as e:
+         ocr_text = f"OCR error: {e}"
+
+     # 3. Field Extraction
+     fields = {}
+     fields_html = ""
+     try:
+         if receipt_ocr and ocr_results:
+             fields = receipt_ocr.postprocess_receipt(ocr_results)
+
+         fields_html = "<div style='padding: 16px; background: #f8f9fa; border-radius: 12px;'><h4>Extracted Fields</h4>"
+         for name, value in [('Vendor', fields.get('vendor')), ('Date', fields.get('date')),
+                             ('Total', f"${fields.get('total')}" if fields.get('total') else None),
+                             ('Time', fields.get('time'))]:
+             display = value or '<span style="color: #adb5bd;">Not found</span>'
+             fields_html += f"<div style='padding: 8px; background: white; border-radius: 6px; margin: 4px 0;'><b>{name}:</b> {display}</div>"
+         fields_html += "</div>"
+         results['fields'] = fields
+     except Exception as e:
+         fields_html = f"<div style='color: red;'>Extraction error: {e}</div>"
+
+     # 4. Anomaly Detection
+     anomaly_html = ""
+     try:
+         anomaly_result = anomaly_detector.predict(fields)
+         status_color = '#dc3545' if anomaly_result['is_anomaly'] else '#28a745'
+         status_text = anomaly_result['prediction']
+
+         anomaly_html = f"""
+         <div style="padding: 16px; background: #f8f9fa; border-radius: 12px; margin: 8px 0;">
+             <h4 style="margin: 0 0 8px 0;">Anomaly Detection</h4>
+             <div style="font-size: 18px; font-weight: bold; color: {status_color};">{status_text}</div>
+         """
+         if anomaly_result['reasons']:
+             anomaly_html += "<ul style='margin: 8px 0; padding-left: 20px;'>"
+             for reason in anomaly_result['reasons']:
+                 anomaly_html += f"<li>{reason}</li>"
+             anomaly_html += "</ul>"
+         anomaly_html += "</div>"
+         results['anomaly'] = anomaly_result
+     except Exception as e:
+         anomaly_html = f"<div style='color: red;'>Anomaly detection error: {e}</div>"
+
+     # 5. Final Decision
+     is_receipt = results.get('classification', {}).get('is_receipt', True)
+     is_anomaly = results.get('anomaly', {}).get('is_anomaly', False)
+     conf = results.get('classification', {}).get('confidence', 0.5)
+
+     if not is_receipt:
+         decision = "REJECT"
+         decision_color = "#dc3545"
+         reason = "Not a receipt"
+     elif is_anomaly:
+         decision = "REVIEW"
+         decision_color = "#ffc107"
+         reason = "Anomaly detected"
+     elif conf < 0.7:
+         decision = "REVIEW"
+         decision_color = "#ffc107"
+         reason = "Low confidence"
+     else:
+         decision = "APPROVE"
+         decision_color = "#28a745"
+         reason = "All checks passed"
+
+     summary_html = f"""
+     <div style="padding: 24px; background: linear-gradient(135deg, {decision_color}22, {decision_color}11);
+                 border-left: 4px solid {decision_color}; border-radius: 12px; text-align: center;">
+         <div style="font-size: 32px; font-weight: bold; color: {decision_color};">{decision}</div>
+         <div style="color: #6c757d; margin-top: 8px;">{reason}</div>
+     </div>
+     {classifier_html}
+     {anomaly_html}
+     {fields_html}
+     """
+
+     # default=str keeps any stray numpy types from breaking serialization
+     return summary_html, ocr_image, ocr_text, "", json.dumps(results, indent=2, default=str)
+
+
+ # ============================================================================
+ # Gradio Interface
+ # ============================================================================
+
+ CUSTOM_CSS = """
+ .gradio-container { max-width: 1200px !important; }
+ .main-header { text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+                border-radius: 12px; color: white; margin-bottom: 20px; }
+ """
+
+ with gr.Blocks(title="Receipt Processing Agent", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
+     gr.Markdown("""
+     <div class="main-header">
+         <h1>Receipt Processing Agent</h1>
+         <p>Ensemble classification, OCR, field extraction, and anomaly detection</p>
+     </div>
+     """)
+
+     gr.Markdown("""
+     ### How It Works
+     Upload a receipt image to automatically:
+     - **Classify** document type with ViT + ResNet ensemble
+     - **Extract text** with EasyOCR (with bounding boxes)
+     - **Extract fields** (vendor, date, total) using regex patterns
+     - **Detect anomalies** with rule-based checks
+     - **Route** to APPROVE / REVIEW / REJECT
+
+     ---
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### Upload Receipt")
+             input_image = gr.Image(type="pil", label="Receipt Image", height=350)
+             process_btn = gr.Button("Process Receipt", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             agent_summary = gr.HTML(
+                 label="Results",
+                 value="<div style='padding: 40px; text-align: center; color: #6c757d;'>Upload an image to begin</div>"
+             )
+
+     with gr.Accordion("OCR Results", open=False):
+         with gr.Row():
+             ocr_image_output = gr.Image(label="Detected Text Regions", height=300)
+             ocr_text_output = gr.Textbox(label="Extracted Text", lines=12)
+
+     with gr.Accordion("Raw Results (JSON)", open=False):
+         results_json = gr.Textbox(label="Full Results", lines=15)
+
+     hidden_state = gr.Textbox(visible=False)
+
+     process_btn.click(
+         fn=process_receipt,
+         inputs=[input_image],
+         outputs=[agent_summary, ocr_image_output, ocr_text_output, hidden_state, results_json]
+     )
+
+     gr.Markdown("""
+     ---
+     ### About This Demo
+
+     This is a simplified version of the full pipeline for demonstration purposes.
+     The complete system includes:
+     - LayoutLMv3 for advanced field extraction
+     - 4-model anomaly detection ensemble
+     - LangGraph agentic workflow
+     - Human feedback loop with model fine-tuning
+
+     **Repository**: [GitHub](https://github.com/RogueTex/StreamingDataforModelTraining)
+     """)
+
+
+ # Launch
+ if __name__ == "__main__":
+     demo.launch()
+
models/anomaly_detector.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d6494401618de7e649b003289c78f817fa9c15498054aae1fca6aa6417415a8b
+ size 1592483
models/layoutlm_extractor.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5df280d9c5a94131b56f81a459774c51fe2fc21d79cb585c275b71919a8f1075
+ size 501421255
models/model_summary.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "models_dir": "/Users/shruthisubramanian/Downloads/models",
+   "models": {
+     "rvl_classifier.pt": "ViT-based document classifier (receipt vs other)",
+     "layoutlm_extractor.pt": "LayoutLMv3 for field extraction (vendor/date/total)",
+     "anomaly_detector.pt": "Isolation Forest for anomaly detection"
+   },
+   "pipeline": {
+     "nodes": [
+       "ingest",
+       "classify",
+       "ocr",
+       "extract",
+       "anomaly",
+       "route"
+     ],
+     "framework": "LangGraph"
+   },
+   "metrics": {
+     "num_samples": 20,
+     "ocr_accuracy": 0.5333333333333333,
+     "vendor_accuracy": 0.75,
+     "date_accuracy": 0.85,
+     "total_accuracy": 0.0,
+     "extraction_f1": 0.6956521739130436,
+     "straight_through_rate": 0.1,
+     "review_rate": 0.9,
+     "reject_rate": 0.0,
+     "avg_processing_time": 1.0378559350967407,
+     "error_rate": 0.0
+   }
+ }
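One quick consistency check on the metrics above: the three routing rates should partition the 20-sample evaluation set. A minimal sketch, with the values copied from this file:

```python
# Routing rates copied from models/model_summary.json
metrics = {"straight_through_rate": 0.1, "review_rate": 0.9, "reject_rate": 0.0}
total_rate = sum(metrics.values())
assert abs(total_rate - 1.0) < 1e-9  # every sample was routed exactly once
```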
models/rvl_10k.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b52d495591bad54cf579cffb377aa28d643906c0600471516b5b83351c28d2d
+ size 44793291
models/rvl_classifier.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aec5f448d1970b7ede6fa2dd8d50d98dcb9065bdc0561c295a3e0bf216d800b5
+ size 22180625
models/rvl_resnet18.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0756c88879133fc5678cc6428e4d9805cc72ced30dc4433a85bbb2dc89b0a15d
+ size 44819147
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ transformers>=4.30.0
+ easyocr>=1.7.0
+ gradio>=4.0.0
+ Pillow>=9.0.0
+ numpy>=1.21.0
+ scikit-learn>=1.0.0
+ opencv-python-headless>=4.5.0
+