vxa8502 committed on
Commit
26ca7c1
·
1 Parent(s): a68e765

Improve validation and resource handling

Browse files
Files changed (1) hide show
  1. scripts/kaggle_pipeline.py +47 -40
scripts/kaggle_pipeline.py CHANGED
@@ -30,7 +30,8 @@ if IS_KAGGLE:
30
 
31
  import subprocess
32
 
33
- packages = ["qdrant-client>=1.7.0", "sentence-transformers>=2.2.0"]
 
34
  for pkg in packages:
35
  subprocess.check_call(
36
  [sys.executable, "-m", "pip", "install", "-q", pkg],
@@ -50,7 +51,7 @@ else:
50
  load_dotenv()
51
  print("Using local .env")
52
 
53
- print(f"QDRANT_URL: {os.environ.get('QDRANT_URL', 'NOT SET')[:40]}...")
54
 
55
  # %% [markdown]
56
  # ## Check GPU
@@ -75,6 +76,7 @@ SUBSET_SIZE = 1_000_000 if IS_KAGGLE else 100_000
75
 
76
  print(f"Loading {SUBSET_SIZE:,} reviews...")
77
  start = time.time()
 
78
  df = prepare_data(subset_size=SUBSET_SIZE, force=True)
79
  print(f"Prepared {len(df):,} reviews in {time.time() - start:.1f}s")
80
 
@@ -111,6 +113,8 @@ print(f"Expansion ratio: {len(chunks) / len(reviews):.2f}x")
111
  # %%
112
  import numpy as np
113
 
 
 
114
  chunk_texts = [c.text for c in chunks]
115
 
116
  cache_dir = Path("/kaggle/working") if IS_KAGGLE else Path("data")
@@ -130,11 +134,16 @@ embed_time = time.time() - start
130
  print(f"Embeddings: {embeddings.shape} in {embed_time:.1f}s")
131
  print(f"Throughput: {len(chunks) / embed_time:.0f} chunks/sec")
132
 
133
- # Validate
134
- assert embeddings.shape[1] == 384, f"Wrong dims: {embeddings.shape[1]}"
135
- assert np.isnan(embeddings).sum() == 0, "NaN values"
 
 
 
 
136
  norms = np.linalg.norm(embeddings, axis=1)
137
- assert np.allclose(norms, 1.0, atol=0.01), "Not normalized"
 
138
  print("Validation: PASSED")
139
 
140
  # %% [markdown]
@@ -147,40 +156,38 @@ from sage.adapters.vector_store import (
147
  upload_chunks,
148
  get_collection_info,
149
  create_payload_indexes,
 
150
  )
151
 
152
- qdrant_url = os.environ.get("QDRANT_URL")
153
- print(f"Uploading to: {qdrant_url[:40]}...")
154
-
155
  client = get_client()
156
- create_collection(client)
157
- create_payload_indexes(client)
158
-
159
- start = time.time()
160
- upload_chunks(client, chunks, embeddings)
161
- print(f"Upload complete in {time.time() - start:.1f}s")
162
-
163
- info = get_collection_info(client)
164
- print("\nCollection info:")
165
- for key, value in info.items():
166
- print(f" {key}: {value}")
167
-
168
- # %% [markdown]
169
- # ## Test Search
170
-
171
- # %%
172
- from sage.adapters.vector_store import search
173
-
174
- query = "wireless headphones with noise cancellation"
175
- query_emb = embedder.embed_single_query(query)
176
- results = search(client, query_emb.tolist(), limit=5)
177
-
178
- print(f"Query: '{query}'\n")
179
- for i, r in enumerate(results):
180
- print(f"{i + 1}. [{r['rating']:.0f}*] {r['text'][:70]}...")
181
-
182
- # %%
183
- client.close()
184
- print(
185
- f"\nDone! {info.get('points_count', len(chunks)):,} chunks indexed to Qdrant Cloud"
186
- )
 
30
 
31
  import subprocess
32
 
33
+ # Pin exact versions matching requirements.txt for reproducibility
34
+ packages = ["qdrant-client==1.12.1", "sentence-transformers==3.3.1"]
35
  for pkg in packages:
36
  subprocess.check_call(
37
  [sys.executable, "-m", "pip", "install", "-q", pkg],
 
51
  load_dotenv()
52
  print("Using local .env")
53
 
54
+ print(f"QDRANT_URL: {'configured' if os.environ.get('QDRANT_URL') else 'NOT SET'}")
55
 
56
  # %% [markdown]
57
  # ## Check GPU
 
76
 
77
  print(f"Loading {SUBSET_SIZE:,} reviews...")
78
  start = time.time()
79
+ # Kaggle kernels are ephemeral - no persistent cache between runs, always regenerate
80
  df = prepare_data(subset_size=SUBSET_SIZE, force=True)
81
  print(f"Prepared {len(df):,} reviews in {time.time() - start:.1f}s")
82
 
 
113
  # %%
114
  import numpy as np
115
 
116
+ from sage.config import EMBEDDING_DIM
117
+
118
  chunk_texts = [c.text for c in chunks]
119
 
120
  cache_dir = Path("/kaggle/working") if IS_KAGGLE else Path("data")
 
134
  print(f"Embeddings: {embeddings.shape} in {embed_time:.1f}s")
135
  print(f"Throughput: {len(chunks) / embed_time:.0f} chunks/sec")
136
 
137
+ # Validate embeddings (explicit checks instead of assert - survives python -O)
138
+ if embeddings.shape[1] != EMBEDDING_DIM:
139
+ raise ValueError(
140
+ f"Wrong embedding dimensions: {embeddings.shape[1]}, expected {EMBEDDING_DIM}"
141
+ )
142
+ if np.isnan(embeddings).any() or np.isinf(embeddings).any():
143
+ raise ValueError("Embeddings contain NaN or Inf values")
144
  norms = np.linalg.norm(embeddings, axis=1)
145
+ if not np.allclose(norms, 1.0, atol=0.01):
146
+ raise ValueError("Embeddings are not normalized")
147
  print("Validation: PASSED")
148
 
149
  # %% [markdown]
 
156
  upload_chunks,
157
  get_collection_info,
158
  create_payload_indexes,
159
+ search,
160
  )
161
 
 
 
 
162
  client = get_client()
163
+ try:
164
+ create_collection(client)
165
+ create_payload_indexes(client)
166
+
167
+ start = time.time()
168
+ upload_chunks(client, chunks, embeddings)
169
+ print(f"Upload complete in {time.time() - start:.1f}s")
170
+
171
+ info = get_collection_info(client)
172
+ print("\nCollection info:")
173
+ for key, value in info.items():
174
+ print(f" {key}: {value}")
175
+
176
+ # %% [markdown]
177
+ # ## Test Search
178
+
179
+ # %%
180
+ query = "wireless headphones with noise cancellation"
181
+ query_emb = embedder.embed_single_query(query)
182
+ results = search(client, query_emb.tolist(), limit=5)
183
+
184
+ print(f"Query: '{query}'\n")
185
+ for i, r in enumerate(results):
186
+ print(f"{i + 1}. [{r['rating']:.0f}*] {r['text'][:70]}...")
187
+
188
+ # %%
189
+ print(
190
+ f"\nDone! {info.get('points_count', len(chunks)):,} chunks indexed to Qdrant Cloud"
191
+ )
192
+ finally:
193
+ client.close()