akryldigital commited on
Commit
99c582d
Β·
verified Β·
1 Parent(s): 89199c5

add gpu mapping

Browse files
Files changed (1) hide show
  1. src/retrieval/context.py +12 -4
src/retrieval/context.py CHANGED
@@ -12,6 +12,9 @@ import torch
12
  import numpy as np
13
  from qdrant_client.http import models as rest
14
 
 
 
 
15
 
16
  try:
17
  from langchain.docstore.document import Document
@@ -57,10 +60,12 @@ class ContextRetriever:
57
  from colbert.infra import Run, ColBERTConfig
58
  from colbert.modeling.checkpoint import Checkpoint
59
  # ColBERT uses late interaction - different implementation needed
 
60
  print(f"βœ… RERANKER: ColBERT model detected ({self.reranker_model_name})")
61
  print(f"πŸ” INTERACTION TYPE: Late interaction (token-level embeddings)")
 
62
 
63
- # Create ColBERT config for CPU mode
64
  colbert_config = ColBERTConfig(
65
  doc_maxlen=300,
66
  query_maxlen=32,
@@ -72,15 +77,18 @@ class ContextRetriever:
72
  # Load checkpoint (e.g. "colbert-ir/colbertv2.0")
73
  self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
74
  self.colbert_model = self.colbert_checkpoint.model
 
 
75
  self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
76
  self.reranker = self._colbert_rerank # attach wrapper function
77
- print(f"βœ… COLBERT: Model and tokenizer loaded successfully")
78
 
79
  else:
80
  # Standard CrossEncoder for BGE and other models
81
  from sentence_transformers import CrossEncoder
82
- self.reranker = CrossEncoder(self.reranker_model_name)
83
- print(f"βœ… RERANKER: Initialized {self.reranker_model_name}")
 
84
  print(f"πŸ” INTERACTION TYPE: Cross-encoder (single relevance score)")
85
  except Exception as e:
86
  print(f"⚠️ Reranker initialization failed: {e}")
 
12
  import numpy as np
13
  from qdrant_client.http import models as rest
14
 
15
+ # Import device detection utility
16
+ from src.utils.device import get_device_for_sentence_transformers
17
+
18
 
19
  try:
20
  from langchain.docstore.document import Document
 
60
  from colbert.infra import Run, ColBERTConfig
61
  from colbert.modeling.checkpoint import Checkpoint
62
  # ColBERT uses late interaction - different implementation needed
63
+ device = get_device_for_sentence_transformers()
64
  print(f"βœ… RERANKER: ColBERT model detected ({self.reranker_model_name})")
65
  print(f"πŸ” INTERACTION TYPE: Late interaction (token-level embeddings)")
66
+ print(f"πŸ–₯️ DEVICE: {device}")
67
 
68
+ # Create ColBERT config with device
69
  colbert_config = ColBERTConfig(
70
  doc_maxlen=300,
71
  query_maxlen=32,
 
77
  # Load checkpoint (e.g. "colbert-ir/colbertv2.0")
78
  self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config)
79
  self.colbert_model = self.colbert_checkpoint.model
80
+ # Move model to device
81
+ self.colbert_model = self.colbert_model.to(device)
82
  self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer
83
  self.reranker = self._colbert_rerank # attach wrapper function
84
+ print(f"βœ… COLBERT: Model and tokenizer loaded successfully on {device}")
85
 
86
  else:
87
  # Standard CrossEncoder for BGE and other models
88
  from sentence_transformers import CrossEncoder
89
+ device = get_device_for_sentence_transformers()
90
+ self.reranker = CrossEncoder(self.reranker_model_name, device=device)
91
+ print(f"βœ… RERANKER: Initialized {self.reranker_model_name} on {device}")
92
  print(f"πŸ” INTERACTION TYPE: Cross-encoder (single relevance score)")
93
  except Exception as e:
94
  print(f"⚠️ Reranker initialization failed: {e}")