Finish-him commited on
Commit
70dec30
·
verified ·
1 Parent(s): b5c2272
Files changed (1) hide show
  1. train.py +65 -20
train.py CHANGED
@@ -1,7 +1,6 @@
1
- # =================== CÓDIGO COMPLETO E ATUALIZADO ===================
2
 
3
  import os
4
- import zipfile
5
  import glob
6
  import json
7
  import csv
@@ -10,24 +9,10 @@ from tqdm.auto import tqdm
10
  from sentence_transformers import SentenceTransformer
11
 
12
  # --- CONFIGURAÇÕES ---
13
- # Lembre-se de colocar o nome correto do seu arquivo .zip!
14
- ZIP_FILENAME = "seus-dados.zip"
15
- EXTRACT_DIR = "/app/dados_extraidos"
16
  # ---------------------
17
 
18
- def setup_data():
19
- """Descompacta os dados se o diretório não existir."""
20
- if not os.path.exists(EXTRACT_DIR) and os.path.exists(ZIP_FILENAME):
21
- print(f"Descompactando '{ZIP_FILENAME}'...")
22
- os.makedirs(EXTRACT_DIR, exist_ok=True)
23
- with zipfile.ZipFile(ZIP_FILENAME, 'r') as zip_ref:
24
- zip_ref.extractall(EXTRACT_DIR)
25
- print("✅ Dados descompactados.")
26
- elif not os.path.exists(ZIP_FILENAME):
27
- print(f"⚠️ Arquivo '{ZIP_FILENAME}' não encontrado. Pulando descompactação.")
28
- else:
29
- print("✅ Dados já parecem estar descompactados.")
30
-
31
  def serialize_item_to_text(item_dict):
32
  """Converte um dicionário em uma string de texto."""
33
  parts = []
@@ -44,8 +29,68 @@ def serialize_item_to_text(item_dict):
44
 
45
  def main():
46
  """Função principal para carregar dados e gerar embeddings."""
47
- setup_data()
48
  csv.field_size_limit(10_000_000)
49
 
 
 
 
 
 
 
 
 
 
 
50
  documents = []
51
- all_files = glob.glob(EXTRACT_DIR + "/**/*.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =================== CÓDIGO COMPLETO E FINAL ===================
2
 
3
  import os
 
4
  import glob
5
  import json
6
  import csv
 
9
  from sentence_transformers import SentenceTransformer
10
 
11
  # --- CONFIGURAÇÕES ---
12
+ # O diretório onde o Dockerfile clonou os dados do próprio Space
13
+ DATA_DIR = "/app/dados"
 
14
  # ---------------------
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def serialize_item_to_text(item_dict):
17
  """Converte um dicionário em uma string de texto."""
18
  parts = []
 
29
 
30
  def main():
31
  """Função principal para carregar dados e gerar embeddings."""
 
32
  csv.field_size_limit(10_000_000)
33
 
34
+ # Procura recursivamente por todos os arquivos .json e .csv no diretório de dados
35
+ all_files = glob.glob(DATA_DIR + "/**/*.json", recursive=True) + \
36
+ glob.glob(DATA_DIR + "/**/*.csv", recursive=True)
37
+
38
+ print(f"🔎 Encontrados {len(all_files)} arquivos para processar no repositório.")
39
+
40
+ if not all_files:
41
+ print("⚠️ Nenhum arquivo .csv ou .json encontrado. Verifique se os dados estão no repositório.")
42
+ return
43
+
44
  documents = []
45
+ for filepath in all_files:
46
+ try:
47
+ if filepath.endswith('.json'):
48
+ with open(filepath, 'r', encoding='utf-8') as f:
49
+ data = json.load(f)
50
+ if isinstance(data, list):
51
+ for item in data: documents.append(serialize_item_to_text(item))
52
+ else:
53
+ documents.append(serialize_item_to_text(data))
54
+ elif filepath.endswith('.csv'):
55
+ with open(filepath, 'r', encoding='utf-8') as f:
56
+ reader = csv.DictReader(f)
57
+ for row in reader: documents.append(serialize_item_to_text(row))
58
+ except Exception as e:
59
+ print(f"⚠️ Erro ao processar o arquivo {filepath}: {e}")
60
+
61
+ print(f"\nProcessamento de arquivos concluído! {len(documents)} documentos foram criados.")
62
+
63
+ if not documents:
64
+ print("Nenhum documento foi lido com sucesso. Encerrando.")
65
+ return
66
+
67
+ # Define o caminho do cache e carrega o modelo
68
+ cache_path = os.environ.get('SENTENCE_TRANSFORMERS_HOME', '/app/cache/torch')
69
+
70
+ print("Carregando modelo avançado: intfloat/e5-mistral-7b-instruct")
71
+ print("Isso pode levar vários minutos, pois o modelo é grande.")
72
+ model = SentenceTransformer(
73
+ 'intfloat/e5-mistral-7b-instruct',
74
+ cache_folder=cache_path,
75
+ trust_remote_code=True
76
+ )
77
+
78
+ batch_size = 64
79
+ output_filename = 'meus_embeddings_finais.npy'
80
+
81
+ if os.path.exists(output_filename):
82
+ os.remove(output_filename)
83
+
84
+ print(f"🚀 Iniciando geração de embeddings (lotes de {batch_size}).")
85
+ for i in tqdm(range(0, len(documents), batch_size)):
86
+ batch = documents[i:i+batch_size]
87
+ batch_embeddings = model.encode(batch, show_progress_bar=False)
88
+ with open(output_filename, 'ab') as f_out:
89
+ np.save(f_out, batch_embeddings)
90
+
91
+ print(f"✅ Processo finalizado! Embeddings salvos em '{output_filename}'.")
92
+
93
+ if __name__ == "__main__":
94
+ main()
95
+
96
+ # =================================================================