X1ng1 commited on
Commit
9d6d7fb
·
1 Parent(s): d784490

updated backend: persist clustering jobs/results in Supabase, dynamic topic-cluster counts, and improved LLM label cleaning

Browse files
cluster_orchestrator.py CHANGED
@@ -186,13 +186,15 @@ class ClusterOrchestrator:
186
  # Use ALL messages in the conversation for context
187
  conversation_texts = [messages[i].text for i in message_indices]
188
 
189
- # Generate conversation label (shorter, more specific)
190
- if len(conversation_texts) <= 3:
191
- # For short conversations, use the first message as label
 
 
 
192
  first_msg = conversation_texts[0]
193
  if len(first_msg) > 60:
194
  truncated = first_msg[:60]
195
- # Try to break at word boundary
196
  last_space = truncated.rfind(' ')
197
  if last_space > 0:
198
  truncated = truncated[:last_space]
@@ -200,8 +202,13 @@ class ClusterOrchestrator:
200
  else:
201
  label = first_msg
202
  else:
203
- # For longer conversations, generate a summary label
204
- label = self.label_service.generate_cluster_label(conversation_texts[:10])
 
 
 
 
 
205
 
206
  # Get channel info (store in metadata, not in label)
207
  channel = messages[message_indices[0]].channel if message_indices else "unknown"
 
186
  # Use ALL messages in the conversation for context
187
  conversation_texts = [messages[i].text for i in message_indices]
188
 
189
+ # Always try to use LLM for better labels if we have enough content
190
+ # Only fall back to simple truncation for extremely short/empty convos
191
+ total_chars = sum(len(t) for t in conversation_texts)
192
+
193
+ if total_chars < 50:
194
+ # Very short conversation: use first message
195
  first_msg = conversation_texts[0]
196
  if len(first_msg) > 60:
197
  truncated = first_msg[:60]
 
198
  last_space = truncated.rfind(' ')
199
  if last_space > 0:
200
  truncated = truncated[:last_space]
 
202
  else:
203
  label = first_msg
204
  else:
205
+ # Use Gemini for proper labeling of the conversation
206
+ # This fixes the issue of "I'll Have Let's" type labels
207
+ label = self.label_service.generate_cluster_label(
208
+ conversation_texts,
209
+ max_messages=10, # Fewer messages needed for single conversation
210
+ max_length=40 # Shorter labels for leaf nodes
211
+ )
212
 
213
  # Get channel info (store in metadata, not in label)
214
  channel = messages[message_indices[0]].channel if message_indices else "unknown"
gemini_label_service.py CHANGED
@@ -48,30 +48,33 @@ class GeminiLabelService:
48
  def generate_cluster_label(
49
  self,
50
  messages: List[str],
51
- max_messages: int = 10,
52
- max_length: int = 50
53
  ) -> str:
54
  """Generate a descriptive label for a cluster"""
55
  if not messages:
56
  return "Empty Cluster"
57
 
58
  selected = messages[:max_messages]
59
- messages_text = "\n".join([f"- {msg[:150]}" for msg in selected])
 
60
 
61
- prompt = f"""Analyze these chat messages and create a clear, descriptive topic label in 3-6 words.
62
- Be specific and concise.
 
 
63
 
64
  Messages:
65
  {messages_text}
66
 
67
- Topic label (3-6 words):"""
68
 
69
  try:
70
  response = self.model.generate_content(
71
  prompt,
72
  generation_config=genai.types.GenerationConfig(
73
- max_output_tokens=20,
74
- temperature=0.7,
75
  )
76
  )
77
 
@@ -124,10 +127,20 @@ Keywords:"""
124
 
125
  def _clean_label(self, label: str) -> str:
126
  """Clean and format label"""
127
- label = label.replace("Topic:", "").replace("topic:", "").strip()
 
 
 
 
 
 
 
 
 
128
  if label and not label[0].isupper():
129
  label = label[0].upper() + label[1:]
130
- return label[:50] if label else "General Discussion"
 
131
 
132
  def _clean_tag(self, tag: str) -> str:
133
  """Clean a tag"""
@@ -136,19 +149,38 @@ Keywords:"""
136
 
137
  def _fallback_label(self, messages: List[str]) -> str:
138
  """Simple fallback if API fails"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  from collections import Counter
140
  words = []
141
  for msg in messages:
142
  words.extend(msg.lower().split())
143
 
144
- stopwords = {"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for"}
145
  words = [w for w in words if w not in stopwords and len(w) > 3]
146
 
147
  if not words:
148
  return "General Discussion"
149
 
150
- common = Counter(words).most_common(3)
151
- return " ".join([word.capitalize() for word, _ in common])
152
 
153
  def _fallback_tags(self, messages: List[str], num_tags: int) -> List[str]:
154
  """Simple fallback tags"""
 
48
  def generate_cluster_label(
49
  self,
50
  messages: List[str],
51
+ max_messages: int = 30,
52
+ max_length: int = 60
53
  ) -> str:
54
  """Generate a descriptive label for a cluster"""
55
  if not messages:
56
  return "Empty Cluster"
57
 
58
  selected = messages[:max_messages]
59
+ # Allow slightly longer context per message
60
+ messages_text = "\n".join([f"- {msg[:200]}" for msg in selected])
61
 
62
+ prompt = f"""Analyze these chat messages from a team collaboration channel.
63
+ Identify the main project, specific technical issue, or key activity being discussed.
64
+ Create a descriptive, specific title (4-8 words) that clearly distinguishes this topic.
65
+ Avoid generic phrases like "Team Discussion" or "Project Update".
66
 
67
  Messages:
68
  {messages_text}
69
 
70
+ Specific Topic Title:"""
71
 
72
  try:
73
  response = self.model.generate_content(
74
  prompt,
75
  generation_config=genai.types.GenerationConfig(
76
+ max_output_tokens=30,
77
+ temperature=0.4, # Lower temperature for more focused results
78
  )
79
  )
80
 
 
127
 
128
  def _clean_label(self, label: str) -> str:
129
  """Clean and format label"""
130
+ # Remove common prefixes/suffixes from LLM output
131
+ prefixes = ["Title:", "Label:", "Topic:", "Subject:", "The topic is", "Discussion about"]
132
+ for prefix in prefixes:
133
+ if label.lower().startswith(prefix.lower()):
134
+ label = label[len(prefix):].strip()
135
+
136
+ # Remove quotes if present
137
+ label = label.strip('"\'')
138
+
139
+ # Capitalize first letter
140
  if label and not label[0].isupper():
141
  label = label[0].upper() + label[1:]
142
+
143
+ return label[:60] if label else "General Discussion"
144
 
145
  def _clean_tag(self, tag: str) -> str:
146
  """Clean a tag"""
 
149
 
150
  def _fallback_label(self, messages: List[str]) -> str:
151
  """Simple fallback if API fails"""
152
+ if not messages:
153
+ return "General Discussion"
154
+
155
+ # Try to use the beginning of the first substantial message
156
+ for msg in messages:
157
+ if len(msg) > 20:
158
+ # Find first sentence or up to 50 chars
159
+ end = msg.find('.')
160
+ if end > 0:
161
+ candidate = msg[:end+1]
162
+ else:
163
+ candidate = msg
164
+
165
+ if len(candidate) > 60:
166
+ candidate = candidate[:60].rsplit(' ', 1)[0] + "..."
167
+
168
+ return candidate
169
+
170
+ # Fallback to word counter if all messages are tiny
171
  from collections import Counter
172
  words = []
173
  for msg in messages:
174
  words.extend(msg.lower().split())
175
 
176
+ stopwords = {"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "is", "are", "of", "with"}
177
  words = [w for w in words if w not in stopwords and len(w) > 3]
178
 
179
  if not words:
180
  return "General Discussion"
181
 
182
+ common = Counter(words).most_common(2)
183
+ return " & ".join([word.capitalize() for word, _ in common])
184
 
185
  def _fallback_tags(self, messages: List[str], num_tags: int) -> List[str]:
186
  """Simple fallback tags"""
hierarchical_clustering_service.py CHANGED
@@ -107,7 +107,11 @@ class HierarchicalClusteringService:
107
 
108
  # Cluster conversations by semantic similarity
109
  if len(conversations) > 1:
110
- topic_labels = self._create_topic_clusters(conversation_embeddings, len(conversations))
 
 
 
 
111
  else:
112
  topic_labels = np.array([0])
113
 
@@ -160,37 +164,45 @@ class HierarchicalClusteringService:
160
  def _create_topic_clusters(
161
  self,
162
  conversation_embeddings: np.ndarray,
163
- n_conversations: int
 
164
  ) -> np.ndarray:
165
  """
166
  Cluster conversations by topic using semantic similarity.
167
  Returns cluster labels for each conversation as numpy array.
168
 
169
- Note: main_cluster_threshold (default 1.2) controls topic granularity.
170
- Lower values = fewer, broader topics; higher values = more, specific topics.
171
  """
172
  if n_conversations < 2:
173
  return np.array([0])
174
 
175
- # First, try with threshold-based clustering
176
- labels = self._cluster_level(
177
- conversation_embeddings,
178
- self.main_cluster_threshold,
179
- self.min_main_cluster_size
180
- )
 
 
 
 
 
181
 
182
- n_clusters = len(np.unique(labels))
 
183
 
184
- # If we have too many clusters, use n_clusters parameter instead
185
- if n_clusters > self.max_clusters:
186
- logger.info(f"Threshold produced {n_clusters} topic clusters, limiting to {self.max_clusters}")
187
  clustering = AgglomerativeClustering(
188
- n_clusters=self.max_clusters,
189
- linkage='ward'
 
190
  )
191
  labels = clustering.fit_predict(conversation_embeddings)
192
-
193
- return labels
 
 
 
194
 
195
  def _cluster_level(
196
  self,
 
107
 
108
  # Cluster conversations by semantic similarity
109
  if len(conversations) > 1:
110
+ topic_labels = self._create_topic_clusters(
111
+ conversation_embeddings,
112
+ len(conversations),
113
+ n_messages
114
+ )
115
  else:
116
  topic_labels = np.array([0])
117
 
 
164
  def _create_topic_clusters(
165
  self,
166
  conversation_embeddings: np.ndarray,
167
+ n_conversations: int,
168
+ n_messages: int
169
  ) -> np.ndarray:
170
  """
171
  Cluster conversations by topic using semantic similarity.
172
  Returns cluster labels for each conversation as numpy array.
173
 
174
+ Uses dynamic cluster counting: ~5% of total messages, capped by max_clusters.
 
175
  """
176
  if n_conversations < 2:
177
  return np.array([0])
178
 
179
+ # Calculate target number of clusters based on message count (5% rule)
180
+ # Example: 100 messages -> 5 clusters
181
+ target_n_clusters = int(n_messages * 0.05)
182
+
183
+ # Ensure reasonable bounds
184
+ # At least 2 clusters (if we have enough conversations)
185
+ # At most max_clusters
186
+ # At most n_conversations (can't have more clusters than items)
187
+ target_n_clusters = max(2, target_n_clusters)
188
+ target_n_clusters = min(target_n_clusters, self.max_clusters)
189
+ target_n_clusters = min(target_n_clusters, n_conversations)
190
 
191
+ logger.info(f"Clustering {n_conversations} conversations into {target_n_clusters} topics "
192
+ f"(based on {n_messages} messages)")
193
 
194
+ try:
 
 
195
  clustering = AgglomerativeClustering(
196
+ n_clusters=target_n_clusters,
197
+ linkage='ward',
198
+ metric='euclidean'
199
  )
200
  labels = clustering.fit_predict(conversation_embeddings)
201
+ return labels
202
+ except Exception as e:
203
+ logger.error(f"Topic clustering failed: {e}")
204
+ # Fallback to single cluster
205
+ return np.zeros(n_conversations, dtype=int)
206
 
207
  def _cluster_level(
208
  self,
main.py CHANGED
@@ -5,6 +5,7 @@ from fastapi import FastAPI, HTTPException, BackgroundTasks
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import JSONResponse
7
  from typing import Dict, Optional
 
8
  import logging
9
  import uuid
10
 
@@ -49,8 +50,9 @@ app.include_router(slack_oauth_router)
49
  app.include_router(discord_oauth_router)
50
 
51
  # Job storage (in production, use Redis or database)
52
- jobs: Dict[str, ClusteringStatus] = {}
53
- results: Dict[str, ClusteringOutput] = {}
 
54
 
55
  @app.on_event("startup")
56
  async def startup_event():
@@ -72,6 +74,7 @@ async def startup_event():
72
  logger.warning(f"Warmup failed ({type(e).__name__}): {e}")
73
  except Exception as e:
74
  logger.warning(f"Warmup failed with unexpected error ({type(e).__name__}): {e}")
 
75
 
76
  @app.get("/")
77
  async def root():
@@ -145,11 +148,20 @@ async def cluster_messages_async(
145
  job_id = str(uuid.uuid4())
146
 
147
  # Initialize job status
148
- jobs[job_id] = ClusteringStatus(
 
 
 
 
 
 
 
149
  status="processing",
150
  progress=0.0,
151
  message="Starting clustering job",
152
- job_id=job_id
 
 
153
  )
154
 
155
  # Add background task
@@ -165,13 +177,23 @@ async def cluster_messages_async(
165
  async def process_clustering_job(job_id: str, request: ClusteringRequest):
166
  """Background task for clustering"""
167
  try:
168
- jobs[job_id].message = "Processing messages..."
169
- jobs[job_id].progress = 10.0
 
 
 
 
 
170
 
171
  orchestrator = get_orchestrator()
172
 
173
- jobs[job_id].message = "Generating embeddings..."
174
- jobs[job_id].progress = 30.0
 
 
 
 
 
175
 
176
  result = orchestrator.process_messages(
177
  messages=request.messages,
@@ -180,43 +202,62 @@ async def process_clustering_job(job_id: str, request: ClusteringRequest):
180
  min_cluster_size=request.min_cluster_size
181
  )
182
 
183
- jobs[job_id].message = "Clustering complete"
184
- jobs[job_id].progress = 100.0
185
- jobs[job_id].status = "completed"
186
-
187
- results[job_id] = result
 
 
 
 
 
 
188
 
189
  except Exception as e:
190
  logger.error(f"Error in background job {job_id}: {e}", exc_info=True)
191
- jobs[job_id].status = "error"
192
- jobs[job_id].message = str(e)
 
 
 
 
 
193
 
194
 
195
  @app.get("/cluster/status/{job_id}", response_model=ClusteringStatus)
196
  async def get_job_status(job_id: str):
197
  """Get status of a clustering job"""
198
- if job_id not in jobs:
199
- raise HTTPException(status_code=404, detail="Job not found")
200
-
201
- return jobs[job_id]
 
 
202
 
203
 
204
  @app.get("/cluster/result/{job_id}", response_model=ClusteringOutput)
205
  async def get_job_result(job_id: str):
206
  """Get result of a completed clustering job"""
207
- if job_id not in jobs:
208
- raise HTTPException(status_code=404, detail="Job not found")
 
 
209
 
210
- if jobs[job_id].status != "completed":
 
211
  raise HTTPException(
212
  status_code=400,
213
- detail=f"Job not completed. Current status: {jobs[job_id].status}"
214
  )
215
 
216
- if job_id not in results:
217
- raise HTTPException(status_code=404, detail="Result not found")
218
-
219
- return results[job_id]
 
 
 
220
 
221
 
222
  @app.post("/search", response_model=list[SearchResult])
@@ -231,12 +272,48 @@ async def search_messages(request: SearchRequest):
231
  List of search results
232
  """
233
  try:
234
- # This endpoint requires messages to be provided or stored
235
- # For now, return error - in production, integrate with database
236
- raise HTTPException(
237
- status_code=501,
238
- detail="Search endpoint requires integration with message storage"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  )
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  except HTTPException:
242
  raise
@@ -363,6 +440,17 @@ async def fetch_slack_messages(request: SlackFetchRequest):
363
  raise HTTPException(status_code=500, detail=str(e))
364
 
365
 
 
 
 
 
 
 
 
 
 
 
 
366
  if __name__ == "__main__":
367
  import uvicorn
368
 
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.responses import JSONResponse
7
  from typing import Dict, Optional
8
+ from supabase_job_storage import get_job_storage
9
  import logging
10
  import uuid
11
 
 
50
  app.include_router(discord_oauth_router)
51
 
52
  # Job storage (in production, use Redis or database)
53
+ #jobs: Dict[str, ClusteringStatus] = {}
54
+ #results: Dict[str, ClusteringOutput] = {}
55
+ storage = get_job_storage()
56
 
57
  @app.on_event("startup")
58
  async def startup_event():
 
74
  logger.warning(f"Warmup failed ({type(e).__name__}): {e}")
75
  except Exception as e:
76
  logger.warning(f"Warmup failed with unexpected error ({type(e).__name__}): {e}")
77
+ storage.cleanup_old_jobs()
78
 
79
  @app.get("/")
80
  async def root():
 
148
  job_id = str(uuid.uuid4())
149
 
150
  # Initialize job status
151
+ # jobs[job_id] = ClusteringStatus(
152
+ # status="processing",
153
+ # progress=0.0,
154
+ # message="Starting clustering job",
155
+ # job_id=job_id
156
+ # )
157
+ storage.create_job(
158
+ job_id=job_id,
159
  status="processing",
160
  progress=0.0,
161
  message="Starting clustering job",
162
+ distance_threshold=request.distance_threshold,
163
+ min_cluster_size=request.min_cluster_size,
164
+ force_recluster=request.force_recluster
165
  )
166
 
167
  # Add background task
 
177
  async def process_clustering_job(job_id: str, request: ClusteringRequest):
178
  """Background task for clustering"""
179
  try:
180
+ # jobs[job_id].message = "Processing messages..."
181
+ # jobs[job_id].progress = 10.0
182
+ storage.update_job_status(
183
+ job_id=job_id,
184
+ message="Processing messages...",
185
+ progress=10.0
186
+ )
187
 
188
  orchestrator = get_orchestrator()
189
 
190
+ # jobs[job_id].message = "Generating embeddings..."
191
+ # jobs[job_id].progress = 30.0
192
+ storage.update_job_status(
193
+ job_id=job_id,
194
+ message="Generating embeddings...",
195
+ progress=30.0
196
+ )
197
 
198
  result = orchestrator.process_messages(
199
  messages=request.messages,
 
202
  min_cluster_size=request.min_cluster_size
203
  )
204
 
205
+ # jobs[job_id].message = "Clustering complete"
206
+ # jobs[job_id].progress = 100.0
207
+ #jobs[job_id].status = "completed"
208
+ #results[job_id] = result
209
+ storage.save_result(job_id, result.dict())
210
+ storage.update_job_status(
211
+ job_id=job_id,
212
+ message="Clustering complete",
213
+ progress=100.0,
214
+ status="completed"
215
+ )
216
 
217
  except Exception as e:
218
  logger.error(f"Error in background job {job_id}: {e}", exc_info=True)
219
+ # jobs[job_id].status = "error"
220
+ # jobs[job_id].message = str(e)
221
+ storage.update_job_status(
222
+ job_id=job_id,
223
+ message=str(e),
224
+ status="error"
225
+ )
226
 
227
 
228
@app.get("/cluster/status/{job_id}", response_model=ClusteringStatus)
async def get_job_status(job_id: str):
    """Get status of a clustering job.

    Raises:
        HTTPException: 404 if the job does not exist or was cleaned up.
    """
    job_data = storage.get_job(job_id)
    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found or expired")
    return ClusteringStatus(**job_data)
237
 
238
 
239
@app.get("/cluster/result/{job_id}", response_model=ClusteringOutput)
async def get_job_result(job_id: str):
    """Get result of a completed clustering job.

    Raises:
        HTTPException: 404 if the job or its result is missing/expired,
            400 if the job exists but has not completed yet.
    """
    job_data = storage.get_job(job_id)
    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found or expired")

    if job_data["status"] != "completed":
        raise HTTPException(
            status_code=400,
            detail=f"Job not completed. Current status: {job_data['status']}"
        )

    result_data = storage.get_result(job_id)
    if not result_data:
        raise HTTPException(status_code=404, detail="Result not found or expired")
    return ClusteringOutput(**result_data)
261
 
262
 
263
  @app.post("/search", response_model=list[SearchResult])
 
272
  List of search results
273
  """
274
  try:
275
+ orchestrator = get_orchestrator()
276
+
277
+ # If messages are provided in request (not ideal but works for small batches)
278
+ # In a real app, we'd use a job_id or session_id to retrieve stored messages
279
+ if not request.messages_with_tags:
280
+ # Fallback: check if we have a recent result in memory (simple stateful approach)
281
+ # This is a hack for the demo; in prod use a DB
282
+
283
+ # if results:
284
+ # last_job_id = list(results.keys())[-1]
285
+ # request.messages_with_tags = results[last_job_id].messages
286
+ recent_jobs = storage.get_recent_jobs(limit=1)
287
+ if recent_jobs:
288
+ result_data = storage.get_result(recent_jobs[0]["job_id"])
289
+ if result_data:
290
+ request.messages_with_tags = [
291
+ MessageWithTags(**msg) for msg in result_data.get("messages", [])
292
+ ]
293
+ else:
294
+ raise HTTPException(
295
+ status_code=400,
296
+ detail="No context provided for search. Please run clustering first."
297
+ )
298
+
299
+ results_tuples = orchestrator.search_messages(
300
+ query=request.query,
301
+ messages_with_tags=request.messages_with_tags,
302
+ filter_tags=request.filter_tags,
303
+ filter_clusters=request.filter_clusters,
304
+ top_k=request.top_k
305
  )
306
+
307
+ # Convert tuples to SearchResult objects
308
+ search_results = [
309
+ SearchResult(
310
+ message=msg,
311
+ score=score
312
+ )
313
+ for msg, score in results_tuples
314
+ ]
315
+
316
+ return search_results
317
 
318
  except HTTPException:
319
  raise
 
440
  raise HTTPException(status_code=500, detail=str(e))
441
 
442
 
443
@app.post("/admin/cleanup-jobs")
async def cleanup_old_jobs():
    """Manually trigger cleanup of old jobs (>48 hours)"""
    removed = storage.cleanup_old_jobs()
    response = {
        "status": "success",
        "deleted_jobs": removed,
    }
    response["message"] = f"Cleaned up {removed} old jobs"
    return response
452
+
453
+
454
  if __name__ == "__main__":
455
  import uvicorn
456
 
models.py CHANGED
@@ -73,12 +73,15 @@ class SearchRequest(BaseModel):
73
  top_k: int = Field(10, description="Number of results to return")
74
  filter_tags: Optional[List[str]] = Field(None, description="Filter by specific tags")
75
  filter_clusters: Optional[List[str]] = Field(None, description="Filter by specific clusters")
 
 
76
 
77
  class SearchResult(BaseModel):
78
  """Single search result"""
79
  message: MessageWithTags
80
- similarity_score: float = Field(..., description="Similarity score to query")
81
- rank: int = Field(..., description="Result rank")
 
82
 
83
  class SlackFetchRequest(BaseModel):
84
  """Request model for fetching Slack messages"""
 
73
  top_k: int = Field(10, description="Number of results to return")
74
  filter_tags: Optional[List[str]] = Field(None, description="Filter by specific tags")
75
  filter_clusters: Optional[List[str]] = Field(None, description="Filter by specific clusters")
76
+ # Add optional context messages for the search
77
+ messages_with_tags: Optional[List[MessageWithTags]] = Field(None, description="Context messages to search within")
78
 
79
class SearchResult(BaseModel):
    """Single search result returned by the /search endpoint."""
    message: MessageWithTags
    # Similarity score against the query; result rank is implied by the
    # position in the returned list, so no explicit rank field is kept.
    score: float = Field(..., description="Similarity score to query")
85
 
86
  class SlackFetchRequest(BaseModel):
87
  """Request model for fetching Slack messages"""
supabase_job_storage.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Supabase-based job storage for clustering jobs.
3
+ Prevents memory leaks by storing jobs in PostgreSQL instead of in-memory dictionaries.
4
+ """
5
import json
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential

from database import get_client
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class SupabaseJobStorage:
    """
    Manages clustering job storage in Supabase (PostgreSQL).

    Features:
    - Persistent storage (survives server restarts)
    - Automatic cleanup of old jobs
    - Supports multiple servers
    - Transaction safety

    Error-handling contract: public methods never raise. Database failures
    are logged and reported as a falsy return value (False / None / [] / 0)
    so a storage outage degrades gracefully instead of crashing request
    handlers. Transient failures are retried inside _execute().
    """

    def __init__(self, retention_hours: int = 48):
        """
        Initialize Supabase job storage.

        Args:
            retention_hours: How long to keep jobs before cleanup (default 48 hours)
        """
        self.client = get_client()
        self.retention_hours = retention_hours

        if self.client is None:
            logger.warning("Supabase client not configured. Job storage will fail.")

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
    def _execute(self, query):
        """Run a Supabase query builder, retrying transient failures.

        The retry lives here rather than on the public methods: those catch
        every exception internally, which previously made their @retry
        decorators unreachable no-ops.
        """
        return query.execute()

    def create_job(
        self,
        job_id: str,
        status: str = "processing",
        progress: float = 0.0,
        message: str = "",
        user_id: Optional[str] = None,
        distance_threshold: Optional[float] = None,
        min_cluster_size: Optional[int] = None,
        force_recluster: bool = False
    ) -> bool:
        """Create a new clustering job in the database.

        Returns:
            True on success; False if the client is missing or the insert fails.
        """
        if self.client is None:
            return False

        payload = {
            "job_id": job_id,
            "status": status,
            "progress": progress,
            "message": message,
            "user_id": user_id,
            "distance_threshold": distance_threshold,
            "min_cluster_size": min_cluster_size,
            "force_recluster": force_recluster
        }

        try:
            self._execute(self.client.table("clustering_jobs").insert(payload))
            logger.info(f"Created job {job_id} in database")
            return True
        except Exception as e:
            logger.error(f"Failed to create job {job_id}: {e}")
            return False

    def update_job_status(
        self,
        job_id: str,
        status: Optional[str] = None,
        progress: Optional[float] = None,
        message: Optional[str] = None
    ) -> bool:
        """Update job status, progress, or message.

        Only non-None fields are written; marking a job "completed" also
        stamps completed_at. Returns True on success (including the no-op
        case where nothing was passed), False on failure.
        """
        if self.client is None:
            return False

        try:
            updates: Dict[str, Any] = {}
            if status is not None:
                updates["status"] = status
                if status == "completed":
                    # Timezone-aware timestamp; datetime.utcnow() is
                    # deprecated since Python 3.12.
                    updates["completed_at"] = datetime.now(timezone.utc).isoformat()
            if progress is not None:
                updates["progress"] = progress
            if message is not None:
                updates["message"] = message

            if not updates:
                # Nothing to write; treat as a successful no-op.
                return True

            self._execute(
                self.client.table("clustering_jobs")
                .update(updates)
                .eq("job_id", job_id)
            )
            return True
        except Exception as e:
            logger.error(f"Failed to update job {job_id}: {e}")
            return False

    def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Get job by ID. Returns the job row as a dict, or None if absent."""
        if self.client is None:
            return None

        try:
            response = self._execute(
                self.client.table("clustering_jobs")
                .select("*")
                .eq("job_id", job_id)
            )
            if response.data:
                return response.data[0]
            return None
        except Exception as e:
            logger.error(f"Failed to get job {job_id}: {e}")
            return None

    def save_result(self, job_id: str, result_data: Dict[str, Any]) -> bool:
        """Save clustering result for a job (upsert, so re-runs overwrite)."""
        if self.client is None:
            return False

        try:
            payload = {
                "job_id": job_id,
                "result_data": result_data  # Supabase stores this as JSONB
            }
            self._execute(self.client.table("clustering_results").upsert(payload))
            logger.info(f"Saved result for job {job_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to save result for job {job_id}: {e}")
            return False

    def get_result(self, job_id: str) -> Optional[Dict[str, Any]]:
        """Get clustering result by job ID, or None if absent/unavailable."""
        if self.client is None:
            return None

        try:
            response = self._execute(
                self.client.table("clustering_results")
                .select("result_data")
                .eq("job_id", job_id)
            )
            if response.data:
                return response.data[0]["result_data"]
            return None
        except Exception as e:
            logger.error(f"Failed to get result for job {job_id}: {e}")
            return None

    def cleanup_old_jobs(self, hours: Optional[int] = None) -> int:
        """
        Delete jobs older than specified hours.

        Args:
            hours: Retention period in hours (uses instance default if None)

        Returns:
            Number of jobs deleted (0 on failure).
        """
        if self.client is None:
            return 0

        hours = hours or self.retention_hours

        try:
            # Preferred path: a server-side PostgreSQL function does the
            # delete in one transaction.
            response = self._execute(self.client.rpc("cleanup_old_clustering_jobs", {
                "retention_hours": hours
            }))
            deleted_count = response.data if response.data is not None else 0
            if deleted_count > 0:
                logger.info(f"Cleaned up {deleted_count} old clustering jobs")
            return deleted_count
        except Exception as e:
            # Fallback: manual deletion if function doesn't exist
            logger.warning(f"RPC function not available, using manual cleanup: {e}")
            return self._manual_cleanup(hours)

    def _manual_cleanup(self, hours: int) -> int:
        """Client-side cleanup fallback used when the RPC function is missing."""
        try:
            cutoff = (datetime.now(timezone.utc) - timedelta(hours=hours)).isoformat()

            old_jobs = self._execute(
                self.client.table("clustering_jobs")
                .select("job_id")
                .lt("created_at", cutoff)
            )
            job_ids = [job["job_id"] for job in (old_jobs.data or [])]
            if not job_ids:
                return 0

            # Delete results first (due to foreign key)
            self._execute(
                self.client.table("clustering_results")
                .delete()
                .in_("job_id", job_ids)
            )
            self._execute(
                self.client.table("clustering_jobs")
                .delete()
                .in_("job_id", job_ids)
            )

            logger.info(f"Manually cleaned up {len(job_ids)} old clustering jobs")
            return len(job_ids)
        except Exception as e2:
            logger.error(f"Manual cleanup also failed: {e2}")
            return 0

    def get_recent_jobs(self, limit: int = 10, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get recent jobs, optionally filtered by user.

        Args:
            limit: Maximum number of jobs to return
            user_id: Optional user ID to filter by

        Returns:
            List of job dictionaries (empty on failure).
        """
        if self.client is None:
            return []

        try:
            query = (
                self.client.table("clustering_jobs")
                .select("*")
                .order("created_at", desc=True)
                .limit(limit)
            )
            if user_id:
                query = query.eq("user_id", user_id)

            response = self._execute(query)
            return response.data or []
        except Exception as e:
            logger.error(f"Failed to get recent jobs: {e}")
            return []

    def _count_jobs(self, status: Optional[str] = None) -> int:
        """Exact server-side row count of jobs, optionally filtered by status."""
        query = self.client.table("clustering_jobs").select("job_id", count="exact")
        if status is not None:
            query = query.eq("status", status)
        response = self._execute(query)
        # count="exact" puts the total in .count; len(data) would under-count
        # once the table exceeds one page, so only use it as a fallback.
        count = getattr(response, "count", None)
        return count if count is not None else len(response.data or [])

    def get_stats(self) -> Dict[str, Any]:
        """Get storage statistics."""
        if self.client is None:
            return {
                "total_jobs": 0,
                "active_jobs": 0,
                "completed_jobs": 0,
                "error_jobs": 0,
                "storage_type": "supabase (not connected)"
            }

        try:
            return {
                "total_jobs": self._count_jobs(),
                "active_jobs": self._count_jobs("processing"),
                "completed_jobs": self._count_jobs("completed"),
                "error_jobs": self._count_jobs("error"),
                "storage_type": "supabase (postgresql)",
                "retention_hours": self.retention_hours
            }
        except Exception as e:
            logger.error(f"Failed to get stats: {e}")
            return {
                "total_jobs": 0,
                "active_jobs": 0,
                "completed_jobs": 0,
                "error_jobs": 0,
                "storage_type": "supabase (error)",
                "error": str(e)
            }
321
+
322
+
323
# Lazily-created process-wide singleton.
_storage_instance: Optional[SupabaseJobStorage] = None


def get_job_storage() -> SupabaseJobStorage:
    """Return the shared job-storage instance, constructing it on first use."""
    global _storage_instance
    instance = _storage_instance
    if instance is None:
        instance = SupabaseJobStorage()
        _storage_instance = instance
    return instance