"""Embed text chunks from a JSONL file and upload them to a Qdrant collection."""
import json
import os
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
from config import MODEL_NAME, COLLECTION_NAME, EMBEDDING_DIM
load_dotenv()
def load_chunks(jsonl_path="chunks.jsonl"):
chunks = []
with open(jsonl_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
chunks.append(json.loads(line))
return chunks
def create_qdrant_client():
return QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
def load_embedding_model():
print("Loading embedding model...")
return SentenceTransformer(MODEL_NAME)
def create_collection(client, collection_name=COLLECTION_NAME, embedding_dim=EMBEDDING_DIM):
print(f"Creating collection '{collection_name}'...")
try:
if client.collection_exists(collection_name):
print(f"Collection '{collection_name}' exists. Deleting...")
client.delete_collection(collection_name)
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
return True
except Exception as e:
print(f"Error creating collection: {e}")
return False
def generate_embeddings(model, chunks):
print("Generating embeddings...")
texts = [chunk["text"] for chunk in chunks]
return model.encode(texts, show_progress_bar=True)
def create_points(chunks, embeddings):
points = []
for idx, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
point = PointStruct(
id=idx,
vector=embedding.tolist(),
payload={
"url": chunk["url"],
"title": chunk["title"],
"date": chunk["date"],
"chunk_id": chunk["chunk_id"],
"text": chunk["text"],
}
)
points.append(point)
return points
def upload_points(client, points, collection_name=COLLECTION_NAME, batch_size=100):
print(f"Uploading {len(points)} points in batches of {batch_size}...")
total_batches = (len(points) + batch_size - 1) // batch_size
for i in range(0, len(points), batch_size):
batch = points[i:i + batch_size]
batch_num = (i // batch_size) + 1
print(f" Batch {batch_num}/{total_batches}: Uploading {len(batch)} points...")
client.upsert(collection_name=collection_name, points=batch)
print(f"✓ Uploaded {len(points)} chunks to collection '{collection_name}'")
def verify_upload(client, collection_name=COLLECTION_NAME):
collection_info = client.get_collection(collection_name)
print(f"Collection now has {collection_info.points_count} points")
return collection_info.points_count
def ensure_collection_exists(client, collection_name=COLLECTION_NAME, embedding_dim=EMBEDDING_DIM):
"""Ensure collection exists, create if it doesn't. Returns starting ID for new points."""
if not client.collection_exists(collection_name):
print(f"Collection '{collection_name}' doesn't exist. Creating...")
client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
return 0
else:
collection_info = client.get_collection(collection_name)
point_count = collection_info.points_count
print(f"Collection '{collection_name}' exists with {point_count} points")
return point_count
def offset_point_ids(points, start_id):
"""Update point IDs to start from a given offset."""
print(f"Setting point IDs starting from {start_id}...")
for i, point in enumerate(points):
point.id = start_id + i
return points
def print_upload_summary(start_id, added_count, new_count):
"""Print upload summary statistics."""
print(f"\n✓ Upload complete!")
print(f" Previous: {start_id} points")
print(f" Added: {added_count} points")
print(f" Total now: {new_count} points")
def upload_chunks_additive(chunks_file="chunks.jsonl"):
"""Upload chunks to Qdrant additively (preserves existing data)."""
if not os.path.exists(chunks_file):
print(f"Chunks file '{chunks_file}' not found")
return
chunks = load_chunks(chunks_file)
print(f"Found {len(chunks)} chunks")
if not chunks:
print("No chunks to upload")
return
client = create_qdrant_client()
start_id = ensure_collection_exists(client)
model = load_embedding_model()
embeddings = generate_embeddings(model, chunks)
points = create_points(chunks, embeddings)
points = offset_point_ids(points, start_id)
upload_points(client, points)
new_count = verify_upload(client)
print_upload_summary(start_id, len(points), new_count)
def main():
upload_chunks_additive()
if __name__ == "__main__":
main() |