Finish-him commited on
Commit
e302caf
·
verified ·
1 Parent(s): a750daa

delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -117
train.py DELETED
@@ -1,117 +0,0 @@
1
- import os
2
- import glob
3
- import json
4
- import csv
5
- import numpy as np
6
- from tqdm.auto import tqdm
7
- from sentence_transformers import SentenceTransformer
8
- import zipfile
9
- import xml.etree.ElementTree as ET
10
-
11
- DATA_DIR = "/app/dados"
12
- EXTRACT_DIR = "/app/dados_extraidos"
13
-
14
- def setup_data():
15
- os.makedirs(EXTRACT_DIR, exist_ok=True)
16
- zip_files = glob.glob(DATA_DIR + "/**/*.zip", recursive=True)
17
- if not zip_files:
18
- print("Nenhum arquivo .zip encontrado, usando o diretório de dados principal.")
19
- return DATA_DIR
20
- for zip_path in zip_files:
21
- print(f"Descompactando {zip_path}...")
22
- with zipfile.ZipFile(zip_path, 'r') as zf:
23
- zf.extractall(EXTRACT_DIR)
24
- return EXTRACT_DIR
25
-
26
- def xml_to_dict(element):
27
- d = {}
28
- for child in element:
29
- child_dict = xml_to_dict(child)
30
- if child.tag in d:
31
- if not isinstance(d[child.tag], list):
32
- d[child.tag] = [d[child.tag]]
33
- d[child.tag].append(child_dict)
34
- else:
35
- d[child.tag] = child_dict
36
- if not d:
37
- return element.text
38
- return d
39
-
40
- def serialize_item_to_text(item_dict):
41
- parts = []
42
- if not isinstance(item_dict, dict):
43
- return str(item_dict)
44
- for key, value in item_dict.items():
45
- if isinstance(value, dict):
46
- nested_text = serialize_item_to_text(value)
47
- parts.append(f"{key} ({nested_text})")
48
- elif isinstance(value, list):
49
- list_str = ', '.join([serialize_item_to_text(i) for i in value])
50
- parts.append(f"{key}: [{list_str}]")
51
- else:
52
- parts.append(f"{key}: {value}")
53
- return ", ".join(parts)
54
-
55
- def main():
56
- process_dir = setup_data()
57
- csv.field_size_limit(10_000_000)
58
-
59
- all_files = glob.glob(process_dir + "/**/*.json", recursive=True) + \
60
- glob.glob(process_dir + "/**/*.csv", recursive=True) + \
61
- glob.glob(process_dir + "/**/*.xml", recursive=True)
62
- print(f"\n🔎 Encontrados {len(all_files)} arquivos (JSON, CSV, XML) para processar.")
63
-
64
- if not all_files:
65
- return
66
-
67
- documents = []
68
- for filepath in tqdm(all_files, desc="Processando arquivos"):
69
- try:
70
- if filepath.endswith('.json'):
71
- with open(filepath, 'r', encoding='utf-8') as f:
72
- data = json.load(f)
73
- if isinstance(data, list):
74
- for item in data: documents.append(serialize_item_to_text(item))
75
- else:
76
- documents.append(serialize_item_to_text(data))
77
- elif filepath.endswith('.csv'):
78
- with open(filepath, 'r', encoding='utf-8') as f:
79
- reader = csv.DictReader(f)
80
- for row in reader: documents.append(serialize_item_to_text(row))
81
- elif filepath.endswith('.xml'):
82
- tree = ET.parse(filepath)
83
- root = tree.getroot()
84
- xml_dict = {root.tag: xml_to_dict(root)}
85
- documents.append(serialize_item_to_text(xml_dict))
86
- except Exception as e:
87
- print(f"⚠️ Erro ao processar o arquivo {filepath}: {e}")
88
-
89
- print(f"\nProcessamento de arquivos concluído! {len(documents)} documentos foram criados.")
90
- if not documents:
91
- return
92
-
93
- cache_path = os.environ.get('SENTENCE_TRANSFORMERS_HOME', '/app/cache/torch')
94
-
95
- print("Carregando modelo de alta performance: intfloat/multilingual-e5-large")
96
- model = SentenceTransformer(
97
- 'intfloat/multilingual-e5-large',
98
- cache_folder=cache_path
99
- )
100
-
101
- batch_size = 128
102
- output_filename = '/app/output/meus_embeddings_e5_large.npy'
103
-
104
- if os.path.exists(output_filename):
105
- os.remove(output_filename)
106
-
107
- print(f"🚀 Iniciando geração de embeddings (lotes de {batch_size}).")
108
- for i in tqdm(range(0, len(documents), batch_size), desc="Gerando Embeddings"):
109
- batch = documents[i:i + batch_size]
110
- batch_embeddings = model.encode(batch, show_progress_bar=False)
111
- with open(output_filename, 'ab') as f_out:
112
- np.save(f_out, batch_embeddings)
113
-
114
- print(f"✅ Processo finalizado! Embeddings salvos em '{output_filename}'.")
115
-
116
- if __name__ == "__main__":
117
- main()