Spaces:
Runtime error
Runtime error
import os
import socket
import time
from pathlib import Path

from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer

from .multimodal_dispatcher import ImageEmbedder, TextEmbedder, TRANSFORMERS_AVAILABLE
class TopicAwareRetriever:
    """Look up video clips in a Qdrant collection that best match a text query.

    The query is embedded either with ``ImageEmbedder.encode_text`` (when
    ``TRANSFORMERS_AVAILABLE`` is true) or with a SentenceTransformer model
    whose output is padded/truncated to ``EMBEDDING_DIM``.  The embedding is
    then searched against ``COLLECTION_NAME`` and every hit is returned as a
    plain dict (see ``_process_results``).

    Connection handling: the instance keeps exactly one client in
    ``self.client`` and remembers how it was obtained in ``self._client_kind``
    so that it never closes a client it does not own and never tries to open
    the local storage folder twice at the same time.
    """

    # Collection searched by retrieve() and probed when connecting.
    COLLECTION_NAME = "video_chunks"
    # Length every sentence-transformer query vector is padded/truncated to.
    EMBEDDING_DIM = 512
    # Timeout (seconds) passed to every QdrantClient this class creates.
    CLIENT_TIMEOUT = 30
    # Timeout (seconds) passed to each individual search call.
    SEARCH_TIMEOUT = 10
    # Maximum number of search attempts per retrieve() call.
    MAX_RETRIES = 5
    # Payload fields requested from Qdrant for every hit.
    PAYLOAD_FIELDS = ("start", "end", "video_id", "topics", "text")
    # Payload fields a hit must carry to be turned into a clip.
    REQUIRED_PAYLOAD_FIELDS = ("video_id", "start", "end")

    # How self.client was obtained.
    _KIND_LOCAL = "local"    # QdrantClient(path=...) created by this instance
    _KIND_SHARED = "shared"  # singleton imported from llm_engineering (not owned)
    _KIND_REMOTE = "remote"  # QdrantClient(host="localhost") created by this instance
    _client_kind = None      # class-level default; overwritten per instance

    def __init__(self, qdrant_storage_path="/Users/yufeizhen/Desktop/project/qdrant_storage"):
        """Connect to Qdrant and pick the query embedder.

        Args:
            qdrant_storage_path: Folder of the file-based Qdrant storage.
                NOTE(review): the default is a machine-specific absolute path
                (per the original comment it mirrors video_ingester.py);
                callers on other machines should pass their own path.
        """
        self.qdrant_storage_path = qdrant_storage_path
        self._ensure_storage_parent(self.qdrant_storage_path)

        self.client = None
        self._client_kind = None
        # A failed connection is tolerated here; retrieve() retries lazily.
        self._connect_to_qdrant()

        if TRANSFORMERS_AVAILABLE:
            self.embedder = ImageEmbedder()
            self.model = None  # sentence-transformer not loaded in this mode
        else:
            # Text-only fallback; its vectors are resized to EMBEDDING_DIM.
            self.model = SentenceTransformer("all-MiniLM-L6-v2")
            self.embedder = None

    @staticmethod
    def _ensure_storage_parent(storage_path):
        """Create the parent directory of *storage_path* if it has one.

        A path without a directory component (e.g. ``"qdrant_storage"``) has
        an empty parent; ``os.makedirs("")`` would raise, so it is skipped.
        """
        parent = os.path.dirname(storage_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _safe_close(client):
        """Close *client* if it exposes ``close()``, reporting (not raising) errors."""
        close = getattr(client, "close", None)
        if close is None:
            return
        try:
            close()
        except Exception as e:
            print("Ignoring error while closing Qdrant client: {}".format(e))

    def _close_client(self):
        """Detach the current client and close it when this instance owns it."""
        client = getattr(self, "client", None)
        kind = self._client_kind
        self.client = None
        self._client_kind = None
        # The shared singleton may be used elsewhere, so it is only detached.
        if client is not None and kind != self._KIND_SHARED:
            self._safe_close(client)

    def _connect_local(self):
        """Open the file-based storage and report what it contains."""
        client = None
        try:
            # NOTE(review): the original comment describes this as a fix for a
            # connection-reset issue; whether qdrant-client reads this
            # variable is not visible here - confirm.
            os.environ["QDRANT_CLIENT_TIMEOUT"] = str(self.CLIENT_TIMEOUT)
            client = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=self.CLIENT_TIMEOUT,
            )
            print("Connected to Qdrant storage at: {}".format(self.qdrant_storage_path))

            # Verify the connection with a simple operation.
            collections = client.get_collections()
            print("Available collections: {}".format(collections))

            if client.collection_exists(self.COLLECTION_NAME):
                count = client.count(self.COLLECTION_NAME)
                print("Found {} collection with {} points".format(
                    self.COLLECTION_NAME, count.count))
            else:
                print("WARNING: {} collection does not exist - have you ingested videos?".format(
                    self.COLLECTION_NAME))
        except Exception as e:
            print("Error connecting to local Qdrant storage: {}".format(e))
            # Do not leave a half-verified client holding the storage folder.
            self._safe_close(client)
            return False

        self.client = client
        self._client_kind = self._KIND_LOCAL
        return True

    def _connect_shared(self):
        """Adopt the project-wide connection singleton, if it can be imported."""
        try:
            from llm_engineering.infrastructure.db.qdrant import connection
        except Exception as e:
            print("Fallback connection also failed: {}".format(e))
            return False

        self.client = connection
        self._client_kind = self._KIND_SHARED
        print("Using fallback Qdrant connection singleton")
        return True

    def _connect_localhost(self):
        """Connect to a Qdrant server on localhost:6333 and test it."""
        client = None
        try:
            client = QdrantClient(
                host="localhost",
                port=6333,
                timeout=self.CLIENT_TIMEOUT,
            )
            print("Trying localhost connection")
            client.get_collections()  # Test the connection
            print("Connected to Qdrant via localhost")
        except Exception as e:
            print("All connection attempts failed: {}".format(e))
            self._safe_close(client)
            return False

        self.client = client
        self._client_kind = self._KIND_REMOTE
        return True

    def _connect_to_qdrant(self):
        """Establish a Qdrant connection, trying three strategies in order.

        Order: file-based storage at ``qdrant_storage_path``, then the
        ``llm_engineering`` connection singleton, then ``localhost:6333``.

        Returns:
            True when ``self.client`` is usable, False when every strategy
            failed (``self.client`` is then None).
        """
        # Release whatever we hold first so the local storage can be reopened.
        self._close_client()

        for strategy in (self._connect_local, self._connect_shared, self._connect_localhost):
            if strategy():
                return True

        self.client = None
        self._client_kind = None
        return False

    def _create_fresh_connection(self):
        """Replace ``self.client`` with a newly opened file-based client.

        A client this instance opened on the storage folder (or one of
        unknown origin) is closed *before* reopening, because two clients
        cannot use the same local storage folder at once.  A shared or
        localhost client is kept until the new client has been created, so a
        working fallback connection is not thrown away when the local storage
        is unavailable.

        Returns:
            The new client (also stored in ``self.client``), or None when it
            could not be created.
        """
        keep_until_replaced = self._client_kind in (self._KIND_SHARED, self._KIND_REMOTE)
        if not keep_until_replaced:
            self._close_client()

        try:
            print("Creating fresh connection to Qdrant...")
            fresh = QdrantClient(
                path=self.qdrant_storage_path,
                timeout=self.CLIENT_TIMEOUT,
            )
        except Exception as e:
            print("Failed to create fresh connection: {}".format(e))
            return None

        # Retire the previous (non-local) client now that a replacement exists.
        self._close_client()
        self.client = fresh
        self._client_kind = self._KIND_LOCAL
        return fresh

    def _embed_query(self, query):
        """Turn *query* into a vector, or return None when no embedder works."""
        if TRANSFORMERS_AVAILABLE and self.embedder:
            try:
                preview = query[:50] + "..." if len(query) > 50 else query
                print("Encoding query with CLIP: '{}'".format(preview))
                embedding = self.embedder.encode_text(query)
                print("Query embedded successfully")
                return embedding
            except Exception as e:
                print("Error during query embedding with CLIP: {}".format(e))
                if self.model:
                    print("Falling back to sentence transformer model")
                    return self._encode_with_sentence_transformer(query)
                print("No fallback available, returning empty results")
                return None

        # Sentence-transformer path; vector is resized for compatibility.
        return self._encode_with_sentence_transformer(query)

    def retrieve(self, query: str, k: int=3):
        """Return up to *k* clips whose vectors are closest to *query*.

        Args:
            query: Free-text search query.
            k: Maximum number of hits requested from Qdrant.

        Returns:
            A list of clip dicts (see ``_process_results``).  An empty list is
            returned when no connection or embedding is available, when the
            collection is missing, when all retries fail, or on any
            unexpected error.
        """
        if self.client is None:
            print("No Qdrant connection available. Attempting to reconnect...")
            if not self._connect_to_qdrant():
                print("Failed to establish Qdrant connection")
                return []

        query_embedding = self._embed_query(query)
        if query_embedding is None:
            return []

        max_retries = self.MAX_RETRIES
        for attempt in range(1, max_retries + 1):
            try:
                print("Sending search request to Qdrant (attempt {}/{})".format(
                    attempt, max_retries))

                # Each attempt starts from a fresh connection; this also acts
                # as the reconnect step after a failed attempt.
                client = self._create_fresh_connection()
                if client is None:
                    print("Using existing client...")
                    client = self.client
                if client is None:
                    # Nothing left to reuse: run the full fallback chain and
                    # treat a failure as a retryable connection problem.
                    if not self._connect_to_qdrant():
                        raise ConnectionError("no usable Qdrant client")
                    client = self.client

                if not client.collection_exists(self.COLLECTION_NAME):
                    print("ERROR: {} collection doesn't exist in Qdrant".format(
                        self.COLLECTION_NAME))
                    return []

                results = client.search(
                    collection_name=self.COLLECTION_NAME,
                    query_vector=query_embedding,
                    limit=k,
                    with_payload=list(self.PAYLOAD_FIELDS),
                    timeout=self.SEARCH_TIMEOUT,  # shorter than the client timeout
                )
                print("Search successful, found {} results".format(len(results)))
                return self._process_results(results)

            except (UnexpectedResponse, ConnectionError, socket.error) as e:
                print("Qdrant search error (attempt {}/{}): {}".format(
                    attempt, max_retries, e))
                if attempt >= max_retries:
                    print("All retry attempts failed, returning empty results")
                    return []

                # Exponential backoff: 2, 4, 8, 16 seconds between attempts.
                sleep_time = 2 ** attempt
                print("Waiting {} seconds before retrying...".format(sleep_time))
                time.sleep(sleep_time)

            except Exception as other_error:
                print("Unexpected error during search: {}".format(other_error))
                return []

        return []

    def _encode_with_sentence_transformer(self, query):
        """Embed *query* with ``self.model`` and resize to ``EMBEDDING_DIM``.

        Shorter vectors are right-padded with 0.0, longer ones are truncated.

        Returns:
            A list of ``EMBEDDING_DIM`` floats.  If encoding fails for any
            reason (including ``self.model`` being None) a zero vector is
            returned as a last resort.
            NOTE(review): a zero query vector cannot rank results
            meaningfully; callers may want to treat it as a failure.
        """
        dim = self.EMBEDDING_DIM
        try:
            print("Using sentence-transformer for query embedding")
            embed = self.model.encode(query)
            size = len(embed)
            values = embed.tolist()

            if size < dim:
                print("Padding embedding from {} to {} dimensions".format(size, dim))
                return values + [0.0] * (dim - size)
            if size > dim:
                print("Truncating embedding from {} to {} dimensions".format(size, dim))
                return values[:dim]
            return values
        except Exception as e:
            print("Error encoding with sentence transformer: {}".format(e))
            return [0.0] * dim

    def _process_results(self, results):
        """Convert search hits into clip dicts.

        Each hit must expose ``payload`` and ``score``.  Hits whose payload is
        missing or lacks any of ``REQUIRED_PAYLOAD_FIELDS`` are skipped (with
        a message) so that one malformed point does not discard the whole
        result set.

        Returns:
            A list of dicts with keys ``video_id``, ``start``, ``end``,
            ``score``, ``text`` (default ``""``) and ``topics`` (default
            ``[]``); empty when *results* is empty or None.
        """
        if not results:
            return []

        clips = []
        for hit in results:
            payload = hit.payload or {}
            missing = [name for name in self.REQUIRED_PAYLOAD_FIELDS if name not in payload]
            if missing:
                print("Skipping search hit with missing payload fields: {}".format(missing))
                continue
            clips.append({
                "video_id": payload["video_id"],
                "start": payload["start"],
                "end": payload["end"],
                "score": hit.score,
                "text": payload.get("text", ""),  # kept for debugging
                "topics": payload.get("topics", []),
            })
        return clips