pemix09 committed on
Commit
e1f62c6
·
verified ·
1 Parent(s): 43b52cb

Upload folder using huggingface_hub

Browse files
document_type_classifier.tflite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d45cc00b8f10c5e460abb01cb7baa797682e9ee15ee88a9365b478529fd84550
3
+ size 540769452
document_type_classifierlearn.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import json
5
+ from pathlib import Path
6
+ from collections import Counter
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.metrics import classification_report
9
+ from sklearn.preprocessing import LabelEncoder
10
+ from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
11
+
12
# --- PATH CONFIGURATION ---
# Folder that contains this script; the generated artifacts are written next to it
# (adjust if needed).
MODELS_DIR = Path(__file__).resolve().parent
# The script lives in a subfolder, so the dataset root is one level up.
BASE_DIR = MODELS_DIR.parent
DATA_ROOT = BASE_DIR.joinpath("content")
LABEL_ROOT = BASE_DIR.joinpath("type")

# Output artifacts
TFLITE_OUTPUT = MODELS_DIR.joinpath("document_type_classifier.tflite")
LABELS_OUTPUT = MODELS_DIR.joinpath("document_type_labels.txt")

# Model / training parameters
MODEL_ID = "distilbert-base-multilingual-cased"
MIN_SAMPLES_PER_CLASS = 2
MAX_LEN = 256
BATCH_SIZE = 16
EPOCHS = 10  # raised for better accuracy
29
+
30
+
31
def load_data():
    """Load (text, label) training pairs from the mirrored content/label trees.

    Every ``*.txt`` file found recursively under ``DATA_ROOT`` is paired with
    the file at the same relative path under ``LABEL_ROOT``. The text is
    stripped; the label is stripped and lower-cased. Pairs with a missing
    label file, empty text or empty label are skipped and counted.

    Returns:
        A ``(texts, labels)`` tuple of two parallel lists of strings. Both are
        empty when ``DATA_ROOT`` does not exist.
    """
    if not DATA_ROOT.exists():
        print(f"❌ BŁĄD: Nie znaleziono folderu content w: {DATA_ROOT}")
        return [], []

    print(f"📂 Wczytywanie danych z: {DATA_ROOT}")
    texts, labels = [], []
    skipped = 0

    # rglob yields files in filesystem order, which differs between machines.
    # Sorting keeps the sample order (and any seeded split built on it) stable.
    for text_file in sorted(DATA_ROOT.rglob("*.txt")):
        label_file = LABEL_ROOT / text_file.relative_to(DATA_ROOT)
        if not label_file.is_file():
            skipped += 1
            continue

        content = text_file.read_text(encoding="utf-8").strip()
        label = label_file.read_text(encoding="utf-8").strip().lower()
        if not (content and label):
            skipped += 1
            continue

        texts.append(content)
        labels.append(label)

    # Report dropped files instead of discarding them silently.
    if skipped:
        print(f"⚠️ Pominięto {skipped} plików (brak etykiety lub pusta treść).")
    return texts, labels
52
+
53
+
54
def _filter_rare_classes(texts, labels, min_count):
    """Keep only samples whose label occurs at least *min_count* times.

    Returns two parallel lists (texts, labels) in their original order.
    """
    counts = Counter(labels)
    # Set membership keeps the filter linear in the number of samples.
    valid = {cls for cls, count in counts.items() if count >= min_count}
    kept = [(t, l) for t, l in zip(texts, labels) if l in valid]
    return [t for t, _ in kept], [l for _, l in kept]


def _tokenize(tokenizer, texts):
    """Tokenize *texts* into TF tensors padded/truncated to ``MAX_LEN`` tokens."""
    return dict(
        tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="tf",
        )
    )


def _convert_to_tflite(model, pad_token_id):
    """Convert *model* to a float32 TFLite flatbuffer and return its bytes.

    The exported function takes a single ``[1, MAX_LEN]`` int32 ``input_ids``
    tensor and returns the classification logits.
    """

    @tf.function(input_signature=[tf.TensorSpec([1, MAX_LEN], tf.int32, name="input_ids")])
    def serving_fn(input_ids):
        # Training fed an attention mask that hides padding. Rebuild it from
        # the pad token id so inference sees the same thing, while the
        # exported model keeps its single input.
        attention_mask = tf.cast(tf.not_equal(input_ids, pad_token_id), tf.int32)
        # training=False disables dropout, which keeps the exported ops stable.
        outputs = model(input_ids, attention_mask=attention_mask, training=False)
        return outputs.logits

    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [serving_fn.get_concrete_function()], model
    )

    # FORCED COMPATIBILITY:
    # 1. Standard operators (with TF select ops as a fallback).
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    # 2. No optimizations, since they could bump the 'FULLY_CONNECTED' opcode
    #    version to 12. If the model turns out too large, this can be restored
    #    after the Flutter-side libraries are updated.
    converter.optimizations = []
    # 3. Force the output format.
    converter.target_spec.supported_types = [tf.float32]

    return converter.convert()


def main():
    """Train the DistilBERT document-type classifier and export it to TFLite.

    Steps: load the data, drop classes with fewer than
    ``MIN_SAMPLES_PER_CLASS`` samples, encode labels, make a stratified 80/20
    split, fine-tune ``MODEL_ID``, convert to TFLite, then write the model to
    ``TFLITE_OUTPUT`` and the label list to ``LABELS_OUTPUT``.

    Returns ``None``. Exits early, writing nothing, when no usable training
    data is available.
    """
    # 1. Loading and filtering
    texts, labels = load_data()
    if not texts:
        return

    filtered_texts, filtered_labels = _filter_rare_classes(
        texts, labels, MIN_SAMPLES_PER_CLASS
    )
    if not filtered_texts:
        # Without this guard the label encoder / split fail with an opaque error.
        print(f"❌ BŁĄD: Żadna kategoria nie ma co najmniej {MIN_SAMPLES_PER_CLASS} dokumentów.")
        return

    print(f"✅ Załadowano {len(filtered_texts)} dokumentów w {len(set(filtered_labels))} kategoriach.")

    # 2. Label encoding
    label_encoder = LabelEncoder()
    y = label_encoder.fit_transform(filtered_labels)
    num_labels = len(label_encoder.classes_)

    # 3. Train/validation split
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        filtered_texts, y, test_size=0.20, random_state=42, stratify=y
    )

    # 4. Tokenization
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_ID)
    print("⏳ Tokenizacja danych...")
    train_encodings = _tokenize(tokenizer, train_texts)
    val_encodings = _tokenize(tokenizer, val_texts)

    # 5. Model construction
    print("🏗️ Inicjalizacja DistilBERT...")
    model = TFDistilBertForSequenceClassification.from_pretrained(
        MODEL_ID, num_labels=num_labels, from_pt=True
    )
    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=3e-5)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

    # 6. Training
    print("\n🚀 Start uczenia...")
    model.fit(
        x=train_encodings,
        y=train_labels,
        validation_data=(val_encodings, val_labels),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )

    # 7. TFLite conversion (backward-compatible build)
    print("\n🔧 Konwersja do TFLite (Generowanie wersji kompatybilnej z Flutterem)...")
    tflite_model = _convert_to_tflite(model, tokenizer.pad_token_id)

    # Write both artifacts only after a successful conversion, so the label
    # list on disk always matches the model file next to it.
    TFLITE_OUTPUT.write_bytes(tflite_model)
    LABELS_OUTPUT.write_text("\n".join(label_encoder.classes_), encoding="utf-8")

    print("\n✅ SUKCES!")
    print(f"Model: {TFLITE_OUTPUT}")
    print(f"Etykiety: {LABELS_OUTPUT}")
150
+
151
+
152
# Run the training/export pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
document_type_labels.txt ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ application
2
+ authorization
3
+ b2bcontract
4
+ bankstatement
5
+ birthcertificate
6
+ certificate
7
+ cit8
8
+ courtjudgment
9
+ cv
10
+ deathcertificate
11
+ documentscan
12
+ drivinglicense
13
+ employmentcontract
14
+ idcard
15
+ insurancepolicy
16
+ invoice
17
+ landmap
18
+ landregistry
19
+ lawsuit
20
+ loanagreement
21
+ mandatecontract
22
+ marriagecertificate
23
+ medicalhistory
24
+ medicalresults
25
+ noncompeteagreement
26
+ notarialdeed
27
+ other
28
+ passport
29
+ pcc3
30
+ peselconfirmation
31
+ pit11
32
+ pit28
33
+ pit36
34
+ pit36l
35
+ pit37
36
+ pit38
37
+ pit39
38
+ pit5
39
+ pit8c
40
+ powerofattorney
41
+ prescription
42
+ professionalcertificate
43
+ proformainvoice
44
+ propertydeed
45
+ receipt
46
+ referral
47
+ registrationcertificate
48
+ rentalagreement
49
+ sanitarybooklet
50
+ schoolcertificate
51
+ sickleave
52
+ taskcontract
53
+ technicalinspection
54
+ universitydiploma
55
+ utilitybill
56
+ vaccinationcard
57
+ vat7
58
+ vehiclehistory
59
+ vehicleregistration
labels.txt ADDED
File without changes