twissamodi committed on
Commit
62a231e
·
1 Parent(s): 7351ab3

minor fixes

Browse files
Files changed (4) hide show
  1. document_classifier.py +25 -19
  2. rag_setup.py +0 -1
  3. requirements.txt +1 -2
  4. user_data.py +0 -35
document_classifier.py CHANGED
@@ -1,7 +1,6 @@
1
  from langchain_community.document_loaders import PyPDFLoader
2
  from transformers import pipeline
3
  import torch
4
- from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from collections import defaultdict
6
  import time
7
 
@@ -23,12 +22,10 @@ class DocumentClassifier:
23
  self,
24
  pages_per_group=2,
25
  min_confidence=0.35,
26
- max_workers=4,
27
  model_name="cross-encoder/nli-deberta-v3-small"
28
  ):
29
  self.pages_per_group = pages_per_group
30
  self.min_confidence = min_confidence
31
- self.max_workers = max_workers
32
  self.model_name = model_name
33
  self.classifier = None
34
 
@@ -111,25 +108,31 @@ class DocumentClassifier:
111
 
112
  def _classify_groups_parallel(self, groups):
113
  results = []
114
-
115
- with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
116
- future_to_group = {
117
- executor.submit(self._classify_single_group, group): group
118
- for group in groups
119
- }
120
-
121
- for future in as_completed(future_to_group):
122
- group = future_to_group[future]
123
- try:
124
- result = future.result()
125
- result['page_numbers'] = group['page_numbers']
126
- results.append(result)
127
- except Exception as e:
128
- print(f"[Classifier] Group classification failed: {e}")
129
-
 
 
 
 
 
130
  return results
131
 
132
  def _classify_single_group(self, group):
 
133
  text = group['text']
134
 
135
  if not text.strip():
@@ -139,6 +142,9 @@ class DocumentClassifier:
139
 
140
  primary_type = result['labels'][0]
141
  primary_score = result['scores'][0]
 
 
 
142
 
143
  scores = {
144
  label: score
 
1
  from langchain_community.document_loaders import PyPDFLoader
2
  from transformers import pipeline
3
  import torch
 
4
  from collections import defaultdict
5
  import time
6
 
 
22
  self,
23
  pages_per_group=2,
24
  min_confidence=0.35,
 
25
  model_name="cross-encoder/nli-deberta-v3-small"
26
  ):
27
  self.pages_per_group = pages_per_group
28
  self.min_confidence = min_confidence
 
29
  self.model_name = model_name
30
  self.classifier = None
31
 
 
108
 
109
  def _classify_groups_parallel(self, groups):
110
  results = []
111
+ texts = [g['text'] for g in groups]
112
+
113
+ # Use pipeline's native batching — faster than ThreadPoolExecutor,
114
+ # especially on GPU, and avoids thread-safety issues with PyTorch.
115
+ batch_results = self.classifier(texts, self.LABELS, multi_label=True, batch_size=8)
116
+
117
+ for group, result in zip(groups, batch_results):
118
+ primary_type = result['labels'][0]
119
+ primary_score = result['scores'][0]
120
+
121
+ if primary_score < self.min_confidence:
122
+ primary_type = 'other'
123
+
124
+ scores = {label: score for label, score in zip(result['labels'], result['scores'])}
125
+ results.append({
126
+ 'type': primary_type,
127
+ 'confidence': primary_score,
128
+ 'scores': scores,
129
+ 'page_numbers': group['page_numbers']
130
+ })
131
+
132
  return results
133
 
134
  def _classify_single_group(self, group):
135
+ # Kept for single-group use if needed directly
136
  text = group['text']
137
 
138
  if not text.strip():
 
142
 
143
  primary_type = result['labels'][0]
144
  primary_score = result['scores'][0]
145
+
146
+ if primary_score < self.min_confidence:
147
+ primary_type = 'other'
148
 
149
  scores = {
150
  label: score
rag_setup.py CHANGED
@@ -18,7 +18,6 @@ class RAG_Setup:
18
  self.classifier = DocumentClassifier(
19
  pages_per_group=2,
20
  min_confidence=0.35,
21
- max_workers=4,
22
  model_name="cross-encoder/nli-deberta-v3-small"
23
  )
24
 
 
18
  self.classifier = DocumentClassifier(
19
  pages_per_group=2,
20
  min_confidence=0.35,
 
21
  model_name="cross-encoder/nli-deberta-v3-small"
22
  )
23
 
requirements.txt CHANGED
@@ -8,5 +8,4 @@ transformers
8
  sentence-transformers
9
  torch
10
  pypdf
11
- gradio
12
- pyaudioop
 
8
  sentence-transformers
9
  torch
10
  pypdf
11
+ gradio
 
user_data.py CHANGED
@@ -25,12 +25,6 @@ def initialize_db():
25
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
26
  FOREIGN KEY (user_id) REFERENCES users(id)
27
  );
28
-
29
- CREATE TABLE IF NOT EXISTS document_classifications (
30
- file_hash TEXT PRIMARY KEY,
31
- doc_type TEXT NOT NULL,
32
- classified_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
33
- );
34
  """)
35
  conn.commit()
36
  conn.close()
@@ -57,32 +51,3 @@ def user_exists(user_id):
57
  exists = cursor.fetchone() is not None
58
  conn.close()
59
  return exists
60
-
61
-
62
- def get_document_label(file_hash: str):
63
- conn = get_connection()
64
- cursor = conn.cursor()
65
- cursor.execute(
66
- "SELECT doc_type FROM document_classifications WHERE file_hash = ?",
67
- (file_hash,)
68
- )
69
- row = cursor.fetchone()
70
- conn.close()
71
- return row["doc_type"] if row else None
72
-
73
-
74
- def save_document_label(file_hash: str, doc_type: str):
75
- conn = get_connection()
76
- cursor = conn.cursor()
77
- cursor.execute(
78
- """
79
- INSERT INTO document_classifications (file_hash, doc_type)
80
- VALUES (?, ?)
81
- ON CONFLICT(file_hash) DO UPDATE SET
82
- doc_type = excluded.doc_type,
83
- classified_at = CURRENT_TIMESTAMP
84
- """,
85
- (file_hash, doc_type)
86
- )
87
- conn.commit()
88
- conn.close()
 
25
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
26
  FOREIGN KEY (user_id) REFERENCES users(id)
27
  );
 
 
 
 
 
 
28
  """)
29
  conn.commit()
30
  conn.close()
 
51
  exists = cursor.fetchone() is not None
52
  conn.close()
53
  return exists