# Image_generation / embedding_gen.py
# (page residue, preserved as comments: author "manasdhir", commit message
#  "minor changes", commit 5d1cbd9 — these lines were not valid Python)
# Imports and shared Azure OpenAI client setup for the embedding helpers below.
import concurrent.futures
from typing import List, Dict, Any
from openai import AzureOpenAI
# Credentials/endpoint are supplied by the project-local config module.
from config import azure_open_ai_key, azure_open_ai_url
# Module-level client reused by every embedding request in this file.
client = AzureOpenAI(
api_key=azure_open_ai_key,
azure_endpoint=azure_open_ai_url,
api_version="2023-05-15"
)
def embed_text_batches_azure(
    batches: List[List[str]],
    model_name: str = "text-embedding-3-large",
    max_workers: int = 2,
    dimensions: int = 768,
) -> Dict[str, Any]:
    """
    Embed pre-batched text chunks in parallel via Azure OpenAI.

    Uses the module-level Azure OpenAI ``client``; one API request is issued
    per batch, with up to ``max_workers`` requests in flight at once.

    Args:
        batches: List of batches, each batch a list of strings (text chunks).
        model_name: The embedding model deployment name to use.
        max_workers: Number of parallel threads for embedding requests.
        dimensions: Dimensionality requested for each embedding vector
            (default 768, matching the original hard-coded value).

    Returns:
        Dict with keys:
            - "content": all text chunks flattened, in input order
            - "embedding": embedding vectors aligned with "content"

    Raises:
        ValueError: if any batch request fails, or if the number of
            embeddings returned does not match the number of input chunks.
    """
    def embed_batch(text_batch: List[str]) -> List[List[float]]:
        # One synchronous embeddings request for a single batch.
        response = client.embeddings.create(
            input=text_batch,
            model=model_name,
            dimensions=dimensions
        )
        return [item.embedding for item in response.data]

    # Flatten all chunks preserving order so content/embedding stay aligned.
    all_chunks = [chunk for batch in batches for chunk in batch]
    all_embeddings: List[List[float]] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit every batch, then collect results in submission order so the
        # flattened embeddings line up with the flattened chunks.
        futures = [executor.submit(embed_batch, batch) for batch in batches]
        for i, future in enumerate(futures, 1):
            try:
                embeddings = future.result()
            except Exception as e:
                # Fail fast with the real cause attached, instead of printing,
                # continuing, and later tripping the length check below with
                # the original error lost.
                raise ValueError(f"Error embedding batch {i}: {e}") from e
            all_embeddings.extend(embeddings)
            print(f"Batch {i} embeddings received, total embeddings so far: {len(all_embeddings)}")
    # Verify the length matches (each chunk should have an embedding).
    if len(all_chunks) != len(all_embeddings):
        raise ValueError(f"Mismatch between number of chunks ({len(all_chunks)}) "
                         f"and embeddings ({len(all_embeddings)}) received")
    return {
        "content": all_chunks,
        "embedding": all_embeddings
    }
# NOTE(review): the imports and client construction below duplicate lines at
# the top of this file (the module appears to have been concatenated twice).
# They are harmless — re-imports are no-ops and the second AzureOpenAI call
# simply rebinds the module-level `client` name — but should be removed once
# confirmed nothing depends on this second initialization.
import concurrent.futures
from typing import List, Dict, Any
from openai import AzureOpenAI
from config import azure_open_ai_key, azure_open_ai_url
client = AzureOpenAI(
api_key=azure_open_ai_key,
azure_endpoint=azure_open_ai_url,
api_version="2023-05-15"
)
def embed_docling_chunks_azure(
    docling_chunks: List[Any],  # List of Docling chunk objects
    model_name: str = "text-embedding-3-large",
    max_workers: int = 2,
    batch_size: int = 16,
    dimensions: int = 768,
) -> Dict[str, Any]:
    """
    Embed a list of Docling chunk objects and return Qdrant-ready records.

    Extracts text and lightweight metadata from each chunk, generates
    embeddings in parallel batches via the module-level Azure OpenAI
    ``client``, and pairs each chunk with its vector.

    Args:
        docling_chunks: List of Docling chunk objects with text and metadata.
        model_name: The embedding model deployment name to use.
        max_workers: Number of parallel threads for embedding requests.
        batch_size: Number of chunks per embedding request.
        dimensions: Dimensionality requested for each embedding vector
            (default 768, matching the original hard-coded value).

    Returns:
        Dict with keys:
            - "chunks_data": list of dicts with 'text', 'embedding', 'metadata'
            - "content": list of text content (for backward compatibility)
            - "embedding": list of embedding vectors (for backward compatibility)

    Raises:
        ValueError: if any batch request fails, or if the number of
            embeddings does not match the number of chunks.
    """
    def extract_chunk_data(chunk):
        """Extract text content and metadata from a Docling chunk."""
        # Fall back to str() for chunk objects without a .text attribute.
        text_content = chunk.text if hasattr(chunk, 'text') else str(chunk)
        metadata = {
            "text": text_content,
            "headings": []
        }
        # Headings may live under chunk.meta (adjust if your Docling chunk
        # structure differs) — TODO confirm against the actual chunk schema.
        if hasattr(chunk, 'meta') and chunk.meta:
            if hasattr(chunk.meta, 'headings'):
                metadata["headings"] = chunk.meta.headings
            elif hasattr(chunk.meta, 'heading'):
                metadata["headings"] = [chunk.meta.heading]
        if hasattr(chunk, 'page_no'):
            metadata["page_number"] = chunk.page_no
        if hasattr(chunk, 'doc_items') and chunk.doc_items:
            # Prefer section headings gathered from doc_items when present;
            # this deliberately overrides any meta-derived headings above.
            headings = []
            for item in chunk.doc_items:
                if hasattr(item, 'label') and 'heading' in str(item.label).lower():
                    headings.append(item.text if hasattr(item, 'text') else str(item))
            if headings:
                metadata["headings"] = headings
        return text_content, metadata

    def embed_batch(text_batch: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of text chunks."""
        response = client.embeddings.create(
            input=text_batch,
            model=model_name,
            dimensions=dimensions
        )
        return [item.embedding for item in response.data]

    # Extract text and metadata from all chunks, preserving input order.
    all_text_content = []
    all_metadata = []
    for chunk in docling_chunks:
        text_content, metadata = extract_chunk_data(chunk)
        all_text_content.append(text_content)
        all_metadata.append(metadata)
    # Slice the flat text list into fixed-size batches for the API.
    text_batches = [
        all_text_content[i:i + batch_size]
        for i in range(0, len(all_text_content), batch_size)
    ]
    all_embeddings: List[List[float]] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Collect in submission order so embeddings stay aligned with text.
        futures = [executor.submit(embed_batch, batch) for batch in text_batches]
        for i, future in enumerate(futures, 1):
            try:
                embeddings = future.result()
            except Exception as e:
                # Fail fast with the real cause attached, instead of printing,
                # continuing, and later hitting the length check with the
                # original error lost.
                raise ValueError(f"Error embedding batch {i}: {e}") from e
            all_embeddings.extend(embeddings)
            print(f"Batch {i} embeddings received, total embeddings so far: {len(all_embeddings)}")
    # Verify lengths match (each chunk should have an embedding).
    if len(all_text_content) != len(all_embeddings):
        raise ValueError(f"Mismatch between number of chunks ({len(all_text_content)}) "
                         f"and embeddings ({len(all_embeddings)}) received")
    # Pair each text/embedding/metadata triple into a Qdrant-ready record.
    chunks_data = [
        {"text": text, "embedding": vector, "metadata": meta}
        for text, vector, meta in zip(all_text_content, all_embeddings, all_metadata)
    ]
    return {
        "chunks_data": chunks_data,
        "content": all_text_content,  # For backward compatibility
        "embedding": all_embeddings   # For backward compatibility
    }