Improve validation and resource handling
scripts/kaggle_pipeline.py  (CHANGED, +47 -40)
@@ -30,7 +30,8 @@ if IS_KAGGLE:
 
     import subprocess
 
+    # Pin exact versions matching requirements.txt for reproducibility
+    packages = ["qdrant-client==1.12.1", "sentence-transformers==3.3.1"]
     for pkg in packages:
         subprocess.check_call(
             [sys.executable, "-m", "pip", "install", "-q", pkg],
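Note on the pins: a quick post-install check can confirm the pinned versions are the ones actually active in the kernel. Sketch only, not part of this change, reusing the two pins from the hunk above:

from importlib.metadata import version

for name, pinned in [("qdrant-client", "1.12.1"), ("sentence-transformers", "3.3.1")]:
    installed = version(name)  # distribution name, not import name
    if installed != pinned:
        raise RuntimeError(f"{name}: expected {pinned}, got {installed}")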
@@ -50,7 +51,7 @@ else:
     load_dotenv()
     print("Using local .env")
 
-print(f"QDRANT_URL: {os.environ.get('QDRANT_URL'
+print(f"QDRANT_URL: {'configured' if os.environ.get('QDRANT_URL') else 'NOT SET'}")
 
 # %% [markdown]
 # ## Check GPU
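Note on the env check: printing 'configured' / 'NOT SET' avoids echoing the Qdrant URL (and anything embedded in it) into the Kaggle log. If more variables need the same treatment, a small loop keeps it uniform. Sketch only; QDRANT_API_KEY here is purely illustrative and not from this change:

import os

for var in ("QDRANT_URL", "QDRANT_API_KEY"):  # QDRANT_API_KEY is illustrative
    print(f"{var}: {'configured' if os.environ.get(var) else 'NOT SET'}")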
@@ -75,6 +76,7 @@ SUBSET_SIZE = 1_000_000 if IS_KAGGLE else 100_000
 
 print(f"Loading {SUBSET_SIZE:,} reviews...")
 start = time.time()
+# Kaggle kernels are ephemeral - no persistent cache between runs, always regenerate
 df = prepare_data(subset_size=SUBSET_SIZE, force=True)
 print(f"Prepared {len(df):,} reviews in {time.time() - start:.1f}s")
 
@@ -111,6 +113,8 @@ print(f"Expansion ratio: {len(chunks) / len(reviews):.2f}x")
 # %%
 import numpy as np
 
+from sage.config import EMBEDDING_DIM
+
 chunk_texts = [c.text for c in chunks]
 
 cache_dir = Path("/kaggle/working") if IS_KAGGLE else Path("data")
@@ -130,11 +134,16 @@ embed_time = time.time() - start
 print(f"Embeddings: {embeddings.shape} in {embed_time:.1f}s")
 print(f"Throughput: {len(chunks) / embed_time:.0f} chunks/sec")
 
-# Validate
+# Validate embeddings (explicit checks instead of assert - survives python -O)
+if embeddings.shape[1] != EMBEDDING_DIM:
+    raise ValueError(
+        f"Wrong embedding dimensions: {embeddings.shape[1]}, expected {EMBEDDING_DIM}"
+    )
+if np.isnan(embeddings).any() or np.isinf(embeddings).any():
+    raise ValueError("Embeddings contain NaN or Inf values")
 norms = np.linalg.norm(embeddings, axis=1)
+if not np.allclose(norms, 1.0, atol=0.01):
+    raise ValueError("Embeddings are not normalized")
 print("Validation: PASSED")
 
 # %% [markdown]
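Note on the validation block: the same three checks (dimension, NaN/Inf, unit norms) could be folded into one helper if another script ever needs them. A rough sketch, where validate_embeddings and the dummy data are illustrative only, not part of this change:

import numpy as np

def validate_embeddings(emb: np.ndarray, expected_dim: int, atol: float = 0.01) -> None:
    # Explicit raises rather than assert, so the checks survive python -O
    if emb.ndim != 2 or emb.shape[1] != expected_dim:
        raise ValueError(f"Wrong embedding shape: {emb.shape}, expected (*, {expected_dim})")
    if np.isnan(emb).any() or np.isinf(emb).any():
        raise ValueError("Embeddings contain NaN or Inf values")
    norms = np.linalg.norm(emb, axis=1)
    if not np.allclose(norms, 1.0, atol=atol):
        raise ValueError("Embeddings are not normalized")

# Example: three unit vectors of dimension 4 pass all checks
validate_embeddings(np.eye(3, 4), expected_dim=4)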
@@ -147,40 +156,38 @@ from sage.adapters.vector_store import (
     upload_chunks,
     get_collection_info,
     create_payload_indexes,
+    search,
 )
 
-qdrant_url = os.environ.get("QDRANT_URL")
-print(f"Uploading to: {qdrant_url[:40]}...")
-
 client = get_client()
+try:
+    create_collection(client)
+    create_payload_indexes(client)
+
+    start = time.time()
+    upload_chunks(client, chunks, embeddings)
+    print(f"Upload complete in {time.time() - start:.1f}s")
+
+    info = get_collection_info(client)
+    print("\nCollection info:")
+    for key, value in info.items():
+        print(f"  {key}: {value}")
+
+# %% [markdown]
+# ## Test Search
+
+# %%
+    query = "wireless headphones with noise cancellation"
+    query_emb = embedder.embed_single_query(query)
+    results = search(client, query_emb.tolist(), limit=5)
+
+    print(f"Query: '{query}'\n")
+    for i, r in enumerate(results):
+        print(f"{i + 1}. [{r['rating']:.0f}*] {r['text'][:70]}...")
+
+# %%
+    print(
+        f"\nDone! {info.get('points_count', len(chunks)):,} chunks indexed to Qdrant Cloud"
+    )
+finally:
+    client.close()
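Note on the resource handling: the try/finally guarantees client.close() runs whether the upload, the collection-info dump, or the test search raises. If a context manager reads better, contextlib.closing gives the same guarantee. A minimal sketch, assuming only what the hunk already shows (get_client() returns an object with a close() method):

from contextlib import closing

with closing(get_client()) as client:  # close() runs on success and on error
    create_collection(client)
    create_payload_indexes(client)
    upload_chunks(client, chunks, embeddings)
    print(get_collection_info(client))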