File size: 7,454 Bytes
6f54a86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch
import os
import glob
import json
import pandas as pd
from sentence_transformers import SentenceTransformer
import numpy as np
from langchain_text_splitters import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain_community.embeddings import HuggingFaceEmbeddings
from . import base_utils as bu

def load_model(model_name):
    """Instantiate and return a SentenceTransformer, pinned to the CPU."""
    model = SentenceTransformer(model_name, device="cpu")
    return model

def get_text_splitter(splitter, chunk_size, chunk_overlap):
    """Return a text-splitter instance for the requested strategy.

    Parameters
    ----------
    splitter : str
        One of ``"recursive"``, ``"tokens"`` or ``"semantic"``. Any other
        value falls back to the recursive character splitter.
    chunk_size : int
        Target chunk size (characters for "recursive", tiktoken tokens
        for "tokens"; unused by "semantic").
    chunk_overlap : int
        Overlap between consecutive chunks (unused by "semantic").
    """
    if splitter == "tokens":
        return CharacterTextSplitter.from_tiktoken_encoder(
            encoding_name="cl100k_base",
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    if splitter == "semantic":
        embeddings_model = HuggingFaceEmbeddings(
            model_name=bu.config["embeddings"]["model_name"])
        return SemanticChunker(
            embeddings=embeddings_model,
        )
    # "recursive" and any unrecognized value share the same default splitter;
    # the original code duplicated this construction in two branches.
    return RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )

def generate_embeddings(input_path, output_folder, model_name, splitter, chunk_size, chunk_overlap, retrieval_model, export_numpy=False, numpy_output_dir=None, max_files=None):
    """Chunk each markdown file, embed the chunks, and persist one .h5 per document.

    Parameters
    ----------
    input_path : str
        Glob pattern matching the source ``.md`` files.
    output_folder : str
        Directory where per-document ``<doc_id>.h5`` files are written.
    model_name : str
        Name stored alongside each embedding row for provenance.
    splitter, chunk_size, chunk_overlap
        Forwarded to :func:`get_text_splitter`.
    retrieval_model
        Object with an ``encode(text)`` method (e.g. a SentenceTransformer).
    export_numpy : bool, keyword
        When True, additionally consolidate all embeddings into a single
        ``embeddings.npy`` plus a ``metadata.jsonl`` file.
    numpy_output_dir : str or None
        Destination for the consolidated export; defaults to ``data/embeddings``.
    max_files : int or None
        Optional cap on the number of markdown files processed.
    """
    text_splitter = get_text_splitter(splitter, chunk_size, chunk_overlap)
    md_files = glob.glob(input_path)
    if not md_files:
        print(f"No .md files found in path: {input_path}")
        return

    os.makedirs(output_folder, exist_ok=True)

    # Remove orphan embedding files whose source .md no longer exists.
    emb_files = glob.glob(os.path.join(output_folder, "*.h5"))
    for file in emb_files:
        filename_without_ext = os.path.splitext(os.path.basename(file))[0]
        corresponding_doc = os.path.join(os.path.dirname(input_path), filename_without_ext + ".md")
        if not os.path.exists(corresponding_doc):
            print(f"Embeddings file {file} has no corresponding .md. Deleting it.")
            os.remove(file)

    all_embeddings = []
    all_metadata = []
    global_idx = 0

    if max_files is not None:
        md_files = md_files[:max_files]

    total_files = len(md_files)
    for i, file in enumerate(md_files, start=1):
        file_name = os.path.basename(file)
        doc_id = os.path.splitext(file_name)[0]
        output_file = os.path.join(output_folder, f"{doc_id}.h5")

        # Reset per iteration: the original used `'text' in locals()`, which
        # let a previous document's text leak into this document's title
        # extraction whenever an existing .h5 was skipped.
        text = None

        if os.path.exists(output_file):
            print(f"Embeddings already exists for {file_name}. Skipping generation and loading existing file for export...")
            embeddings_df = pd.read_hdf(output_file, key="df")
        else:
            progress = (i / total_files) * 100
            print(f"[{i}/{total_files}] ({progress:.1f}%) Generating embeddings for: {file_name}")
            text = bu.load_md(file)

            embeddings_list = []
            content_list = []

            if text.strip():
                chunks = text_splitter.create_documents([text])
                print(f"Chunks generated for document {file_name} : {len(chunks)}")

                for chunk in chunks:
                    embedding = retrieval_model.encode(chunk.page_content)
                    embeddings_list.append(embedding)
                    content_list.append(chunk.page_content)

                embeddings_df = pd.DataFrame(embeddings_list)
                embeddings_df["segment_content"] = content_list
                embeddings_df["model_name"] = model_name
                # Force object columns to str so the HDF5 table format accepts them.
                embeddings_df["segment_content"] = embeddings_df["segment_content"].astype(str)
                embeddings_df["model_name"] = embeddings_df["model_name"].astype(str)

                embeddings_df.to_hdf(output_file, key="df", mode="w", format="table")
            else:
                # Empty document: nothing to embed and nothing persisted.
                embeddings_df = pd.DataFrame()

        if export_numpy and not embeddings_df.empty:
            # Only read the markdown here if the generation branch did not
            # already load it (i.e. we skipped via an existing .h5).
            if text is None:
                text = bu.load_md(file)
            doc_title = bu.extract_title_from_md(text, default=file_name)

            # Last two columns are segment_content / model_name; the rest
            # is the embedding vector.
            emb_values = embeddings_df.iloc[:, :-2].values.astype("float32")
            contents = embeddings_df["segment_content"].tolist()

            for local_idx, (vec, content) in enumerate(zip(emb_values, contents)):
                all_embeddings.append(vec)
                all_metadata.append(
                    {
                        "idx": global_idx,
                        "document_id": doc_id,
                        "document_title": doc_title,
                        "fragment_id": local_idx,
                        "content": content,
                    }
                )
                global_idx += 1

    if export_numpy and all_embeddings:
        numpy_output_dir = numpy_output_dir or os.path.join("data", "embeddings")
        os.makedirs(numpy_output_dir, exist_ok=True)

        embeddings_array = np.vstack(all_embeddings).astype("float32")
        np.save(os.path.join(numpy_output_dir, "embeddings.npy"), embeddings_array)

        metadata_path = os.path.join(numpy_output_dir, "metadata.jsonl")
        with open(metadata_path, "w", encoding="utf-8") as f:
            for meta in all_metadata:
                f.write(json.dumps(meta, ensure_ascii=False) + "\n")

        print(f"Exported consolidated embeddings to {numpy_output_dir}")

def search_query(query, corpus_embeddings, retrieval_model, segment_contents):
    """Encode *query* and return its top-k most similar segments.

    Returns a tuple ``(top_segments, top_similarities)`` where ``k`` is
    read from ``bu.config['retrieve']['top_k']``.
    """
    encoded_query = retrieval_model.encode(query, convert_to_tensor=True)
    scores = retrieval_model.similarity(encoded_query, corpus_embeddings)[0]

    k = bu.config['retrieve']['top_k']
    best_scores, best_indices = torch.topk(scores, k=k)
    matches = [segment_contents[idx] for idx in best_indices]

    return matches, best_scores

def load_embeddings(embeddings_dir):
    """Load every per-document ``.h5`` embedding file under *embeddings_dir*.

    Returns a dict with:
      - ``embeddings``: float32 tensor of all vectors (on CUDA if available)
      - ``segment_contents``: list of chunk texts, aligned with the tensor rows
      - ``num_documents``: number of .h5 files read
      - ``num_segment_contents``: total number of chunks
      - ``model_name``: the single model name found, or "Multiple Models"
        (previously computed but never returned — now exposed to callers)
    """
    embeddings_list = []
    segment_contents_list = []
    model_names_set = set()

    num_documents = 0
    for file_path in glob.glob(os.path.join(embeddings_dir, "*.h5")):
        num_documents += 1
        embeddings_df = pd.read_hdf(file_path, key='df')
        # Last two columns are 'segment_content' and 'model_name';
        # everything before them is the embedding vector.
        embeddings = embeddings_df.iloc[:, :-2].values
        segment_contents = embeddings_df['segment_content'].values
        model_name = embeddings_df['model_name'].values[0]

        embeddings_list.extend(embeddings)
        segment_contents_list.extend(segment_contents)
        model_names_set.add(model_name)

    embeddings_array = np.array(embeddings_list)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embeddings_tensor = torch.tensor(embeddings_array, dtype=torch.float32, device=device)

    model_name = model_names_set.pop() if len(model_names_set) == 1 else "Multiple Models"

    return {
        "embeddings": embeddings_tensor,
        "segment_contents": segment_contents_list,
        "num_documents": num_documents,
        "num_segment_contents": len(segment_contents_list),
        "model_name": model_name,
    }