# train.py — embedding generation pipeline
# Provenance: Hugging Face file "Update train.py", commit ca3769d (verified), 4.24 kB.
import os
import glob
import json
import csv
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
import zipfile
import xml.etree.ElementTree as ET
DATA_DIR = "/app/dados"
EXTRACT_DIR = "/app/dados_extraidos"
def setup_data():
    """Unpack every .zip archive found under DATA_DIR.

    Returns:
        EXTRACT_DIR when at least one archive was extracted,
        otherwise DATA_DIR unchanged.
    """
    os.makedirs(EXTRACT_DIR, exist_ok=True)
    archives = glob.glob(DATA_DIR + "/**/*.zip", recursive=True)
    if not archives:
        print("Nenhum arquivo .zip encontrado, usando o diretório de dados principal.")
        return DATA_DIR
    for archive_path in archives:
        print(f"Descompactando {archive_path}...")
        # NOTE(review): extractall on untrusted zips can write outside
        # EXTRACT_DIR (path traversal) — confirm archives are trusted.
        with zipfile.ZipFile(archive_path, 'r') as archive:
            archive.extractall(EXTRACT_DIR)
    return EXTRACT_DIR
def xml_to_dict(element):
    """Recursively convert an ElementTree element into plain Python data.

    Child elements become dict entries keyed by tag; repeated tags collapse
    into a list. A leaf element (no children) yields its text (may be None).
    """
    result = {}
    for node in element:
        converted = xml_to_dict(node)
        if node.tag not in result:
            result[node.tag] = converted
        else:
            existing = result[node.tag]
            if isinstance(existing, list):
                existing.append(converted)
            else:
                # Second occurrence of this tag: promote to a list.
                result[node.tag] = [existing, converted]
    return result if result else element.text
def serialize_item_to_text(item_dict):
    """Flatten a (possibly nested) dict into one human-readable string.

    Nested dicts render as "key (inner)", lists as "key: [a, b]", and
    scalar values as "key: value". Non-dict input is stringified as-is.
    """
    if not isinstance(item_dict, dict):
        return str(item_dict)
    rendered = []
    for key, value in item_dict.items():
        if isinstance(value, dict):
            rendered.append(f"{key} ({serialize_item_to_text(value)})")
        elif isinstance(value, list):
            inner = ', '.join(serialize_item_to_text(entry) for entry in value)
            rendered.append(f"{key}: [{inner}]")
        else:
            rendered.append(f"{key}: {value}")
    return ", ".join(rendered)
def main():
    """Scan JSON/CSV/XML files under the data dir, serialize each record to
    text, and write multilingual-e5-large embeddings to an .npy stream.

    Side effects: reads files under the data directory, downloads/loads a
    SentenceTransformer model, and (re)writes the output embeddings file.
    """
    process_dir = setup_data()
    # Some CSVs carry very large cells; raise the csv module's field cap.
    csv.field_size_limit(10_000_000)
    all_files = glob.glob(process_dir + "/**/*.json", recursive=True) + \
                glob.glob(process_dir + "/**/*.csv", recursive=True) + \
                glob.glob(process_dir + "/**/*.xml", recursive=True)
    print(f"\n🔎 Encontrados {len(all_files)} arquivos (JSON, CSV, XML) para processar.")
    if not all_files:
        return

    documents = []
    for filepath in tqdm(all_files, desc="Processando arquivos"):
        try:
            if filepath.endswith('.json'):
                with open(filepath, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                if isinstance(data, list):
                    for item in data:
                        documents.append(serialize_item_to_text(item))
                else:
                    documents.append(serialize_item_to_text(data))
            elif filepath.endswith('.csv'):
                with open(filepath, 'r', encoding='utf-8') as f:
                    for row in csv.DictReader(f):
                        documents.append(serialize_item_to_text(row))
            elif filepath.endswith('.xml'):
                root = ET.parse(filepath).getroot()
                documents.append(serialize_item_to_text({root.tag: xml_to_dict(root)}))
        except Exception as e:
            # Best-effort ingestion: report and skip malformed/unreadable files.
            print(f"⚠️ Erro ao processar o arquivo {filepath}: {e}")

    print(f"\nProcessamento de arquivos concluído! {len(documents)} documentos foram criados.")
    if not documents:
        return

    cache_path = os.environ.get('SENTENCE_TRANSFORMERS_HOME', '/app/cache/torch')
    print("Carregando modelo de alta performance: intfloat/multilingual-e5-large")
    # NOTE(review): e5 models were trained with "query: "/"passage: " input
    # prefixes; documents here are embedded without one — confirm whether
    # "passage: " should be prepended before relying on retrieval quality.
    model = SentenceTransformer(
        'intfloat/multilingual-e5-large',
        cache_folder=cache_path
    )

    batch_size = 128
    output_filename = '/app/output/meus_embeddings_e5_large.npy'
    # Fix: open(..., 'ab') does not create parent directories — ensure
    # /app/output exists before writing.
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    if os.path.exists(output_filename):
        os.remove(output_filename)

    print(f"🚀 Iniciando geração de embeddings (lotes de {batch_size}).")
    # Fix: open the output once instead of reopening per batch. The file
    # contains one concatenated .npy record per batch; read it back with
    # repeated np.load(f) calls on a single open handle until EOF.
    with open(output_filename, 'ab') as f_out:
        for start in tqdm(range(0, len(documents), batch_size), desc="Gerando Embeddings"):
            batch = documents[start:start + batch_size]
            batch_embeddings = model.encode(batch, show_progress_bar=False)
            np.save(f_out, batch_embeddings)
    print(f"✅ Processo finalizado! Embeddings salvos em '{output_filename}'.")
if __name__ == "__main__":
main()