Spaces:
Sleeping
Sleeping
Commit ·
62a231e
1
Parent(s): 7351ab3
minor fixes
Browse files- document_classifier.py +25 -19
- rag_setup.py +0 -1
- requirements.txt +1 -2
- user_data.py +0 -35
document_classifier.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
from langchain_community.document_loaders import PyPDFLoader
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
| 4 |
-
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from collections import defaultdict
|
| 6 |
import time
|
| 7 |
|
|
@@ -23,12 +22,10 @@ class DocumentClassifier:
|
|
| 23 |
self,
|
| 24 |
pages_per_group=2,
|
| 25 |
min_confidence=0.35,
|
| 26 |
-
max_workers=4,
|
| 27 |
model_name="cross-encoder/nli-deberta-v3-small"
|
| 28 |
):
|
| 29 |
self.pages_per_group = pages_per_group
|
| 30 |
self.min_confidence = min_confidence
|
| 31 |
-
self.max_workers = max_workers
|
| 32 |
self.model_name = model_name
|
| 33 |
self.classifier = None
|
| 34 |
|
|
@@ -111,25 +108,31 @@ class DocumentClassifier:
|
|
| 111 |
|
| 112 |
def _classify_groups_parallel(self, groups):
|
| 113 |
results = []
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
return results
|
| 131 |
|
| 132 |
def _classify_single_group(self, group):
|
|
|
|
| 133 |
text = group['text']
|
| 134 |
|
| 135 |
if not text.strip():
|
|
@@ -139,6 +142,9 @@ class DocumentClassifier:
|
|
| 139 |
|
| 140 |
primary_type = result['labels'][0]
|
| 141 |
primary_score = result['scores'][0]
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
scores = {
|
| 144 |
label: score
|
|
|
|
| 1 |
from langchain_community.document_loaders import PyPDFLoader
|
| 2 |
from transformers import pipeline
|
| 3 |
import torch
|
|
|
|
| 4 |
from collections import defaultdict
|
| 5 |
import time
|
| 6 |
|
|
|
|
| 22 |
self,
|
| 23 |
pages_per_group=2,
|
| 24 |
min_confidence=0.35,
|
|
|
|
| 25 |
model_name="cross-encoder/nli-deberta-v3-small"
|
| 26 |
):
|
| 27 |
self.pages_per_group = pages_per_group
|
| 28 |
self.min_confidence = min_confidence
|
|
|
|
| 29 |
self.model_name = model_name
|
| 30 |
self.classifier = None
|
| 31 |
|
|
|
|
| 108 |
|
| 109 |
def _classify_groups_parallel(self, groups):
|
| 110 |
results = []
|
| 111 |
+
texts = [g['text'] for g in groups]
|
| 112 |
+
|
| 113 |
+
# Use pipeline's native batching — faster than ThreadPoolExecutor,
|
| 114 |
+
# especially on GPU, and avoids thread-safety issues with PyTorch.
|
| 115 |
+
batch_results = self.classifier(texts, self.LABELS, multi_label=True, batch_size=8)
|
| 116 |
+
|
| 117 |
+
for group, result in zip(groups, batch_results):
|
| 118 |
+
primary_type = result['labels'][0]
|
| 119 |
+
primary_score = result['scores'][0]
|
| 120 |
+
|
| 121 |
+
if primary_score < self.min_confidence:
|
| 122 |
+
primary_type = 'other'
|
| 123 |
+
|
| 124 |
+
scores = {label: score for label, score in zip(result['labels'], result['scores'])}
|
| 125 |
+
results.append({
|
| 126 |
+
'type': primary_type,
|
| 127 |
+
'confidence': primary_score,
|
| 128 |
+
'scores': scores,
|
| 129 |
+
'page_numbers': group['page_numbers']
|
| 130 |
+
})
|
| 131 |
+
|
| 132 |
return results
|
| 133 |
|
| 134 |
def _classify_single_group(self, group):
|
| 135 |
+
# Kept for single-group use if needed directly
|
| 136 |
text = group['text']
|
| 137 |
|
| 138 |
if not text.strip():
|
|
|
|
| 142 |
|
| 143 |
primary_type = result['labels'][0]
|
| 144 |
primary_score = result['scores'][0]
|
| 145 |
+
|
| 146 |
+
if primary_score < self.min_confidence:
|
| 147 |
+
primary_type = 'other'
|
| 148 |
|
| 149 |
scores = {
|
| 150 |
label: score
|
rag_setup.py
CHANGED
|
@@ -18,7 +18,6 @@ class RAG_Setup:
|
|
| 18 |
self.classifier = DocumentClassifier(
|
| 19 |
pages_per_group=2,
|
| 20 |
min_confidence=0.35,
|
| 21 |
-
max_workers=4,
|
| 22 |
model_name="cross-encoder/nli-deberta-v3-small"
|
| 23 |
)
|
| 24 |
|
|
|
|
| 18 |
self.classifier = DocumentClassifier(
|
| 19 |
pages_per_group=2,
|
| 20 |
min_confidence=0.35,
|
|
|
|
| 21 |
model_name="cross-encoder/nli-deberta-v3-small"
|
| 22 |
)
|
| 23 |
|
requirements.txt
CHANGED
|
@@ -8,5 +8,4 @@ transformers
|
|
| 8 |
sentence-transformers
|
| 9 |
torch
|
| 10 |
pypdf
|
| 11 |
-
gradio
|
| 12 |
-
pyaudioop
|
|
|
|
| 8 |
sentence-transformers
|
| 9 |
torch
|
| 10 |
pypdf
|
| 11 |
+
gradio
|
|
|
user_data.py
CHANGED
|
@@ -25,12 +25,6 @@ def initialize_db():
|
|
| 25 |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 26 |
FOREIGN KEY (user_id) REFERENCES users(id)
|
| 27 |
);
|
| 28 |
-
|
| 29 |
-
CREATE TABLE IF NOT EXISTS document_classifications (
|
| 30 |
-
file_hash TEXT PRIMARY KEY,
|
| 31 |
-
doc_type TEXT NOT NULL,
|
| 32 |
-
classified_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 33 |
-
);
|
| 34 |
""")
|
| 35 |
conn.commit()
|
| 36 |
conn.close()
|
|
@@ -57,32 +51,3 @@ def user_exists(user_id):
|
|
| 57 |
exists = cursor.fetchone() is not None
|
| 58 |
conn.close()
|
| 59 |
return exists
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def get_document_label(file_hash: str):
|
| 63 |
-
conn = get_connection()
|
| 64 |
-
cursor = conn.cursor()
|
| 65 |
-
cursor.execute(
|
| 66 |
-
"SELECT doc_type FROM document_classifications WHERE file_hash = ?",
|
| 67 |
-
(file_hash,)
|
| 68 |
-
)
|
| 69 |
-
row = cursor.fetchone()
|
| 70 |
-
conn.close()
|
| 71 |
-
return row["doc_type"] if row else None
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def save_document_label(file_hash: str, doc_type: str):
|
| 75 |
-
conn = get_connection()
|
| 76 |
-
cursor = conn.cursor()
|
| 77 |
-
cursor.execute(
|
| 78 |
-
"""
|
| 79 |
-
INSERT INTO document_classifications (file_hash, doc_type)
|
| 80 |
-
VALUES (?, ?)
|
| 81 |
-
ON CONFLICT(file_hash) DO UPDATE SET
|
| 82 |
-
doc_type = excluded.doc_type,
|
| 83 |
-
classified_at = CURRENT_TIMESTAMP
|
| 84 |
-
""",
|
| 85 |
-
(file_hash, doc_type)
|
| 86 |
-
)
|
| 87 |
-
conn.commit()
|
| 88 |
-
conn.close()
|
|
|
|
| 25 |
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
| 26 |
FOREIGN KEY (user_id) REFERENCES users(id)
|
| 27 |
);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
""")
|
| 29 |
conn.commit()
|
| 30 |
conn.close()
|
|
|
|
| 51 |
exists = cursor.fetchone() is not None
|
| 52 |
conn.close()
|
| 53 |
return exists
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|