Lev Israel committed on
Commit
1787d7f
·
1 Parent(s): e962d38

Support Cohere

Browse files
Files changed (4) hide show
  1. .claude/settings.local.json +11 -1
  2. app.py +3 -0
  3. models.py +163 -3
  4. requirements.txt +1 -0
.claude/settings.local.json CHANGED
@@ -2,7 +2,17 @@
2
  "permissions": {
3
  "allow": [
4
  "Bash(python -m py_compile:*)",
5
- "Bash(python:*)"
 
 
 
 
 
 
 
 
 
 
6
  ]
7
  }
8
  }
 
2
  "permissions": {
3
  "allow": [
4
  "Bash(python -m py_compile:*)",
5
+ "Bash(python:*)",
6
+ "Bash(grep:*)",
7
+ "Bash(wc:*)",
8
+ "WebSearch",
9
+ "WebFetch(domain:docs.cohere.com)",
10
+ "WebFetch(domain:github.com)",
11
+ "WebFetch(domain:pypi.org)",
12
+ "WebFetch(domain:qdrant.tech)",
13
+ "WebFetch(domain:zilliz.com)",
14
+ "WebFetch(domain:docs.pinecone.io)",
15
+ "WebFetch(domain:cohere.com)"
16
  ]
17
  }
18
  }
app.py CHANGED
@@ -379,6 +379,9 @@ def update_model_inputs_visibility(choice):
379
  elif key_type == "gemini":
380
  label = "Gemini API Key (optional if using gcloud)"
381
  placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
 
 
 
382
  else:
383
  label = "OpenAI API Key"
384
  placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
 
379
  elif key_type == "gemini":
380
  label = "Gemini API Key (optional if using gcloud)"
381
  placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
382
+ elif key_type == "cohere":
383
+ label = "Cohere API Key"
384
+ placeholder = f"Enter your Cohere API key (or set {env_var} env var)"
385
  else:
386
  label = "OpenAI API Key"
387
  placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
models.py CHANGED
@@ -137,6 +137,20 @@ API_MODELS = {
137
  "model_name": "gemini-embedding-001",
138
  "dimensions": 1536,
139
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  }
141
 
142
  # Merge all models for easy lookup
@@ -771,6 +785,140 @@ class GeminiEmbeddingModel(BaseEmbeddingModel):
771
  return self.config.get("description", "")
772
 
773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  def get_curated_model_choices() -> list[tuple[str, str]]:
775
  """
776
  Get list of curated local models for UI dropdown.
@@ -822,7 +970,9 @@ def is_api_model(model_id: str) -> bool:
822
  return True
823
  if model_id.startswith("gemini/"):
824
  return True
825
-
 
 
826
  return False
827
 
828
 
@@ -854,6 +1004,8 @@ def load_model(
854
  return VoyageEmbeddingModel(model_id, api_key=api_key)
855
  elif model_type == "gemini" or model_id.startswith("gemini/"):
856
  return GeminiEmbeddingModel(model_id, api_key=api_key)
 
 
857
  elif model_type == "openai" or model_id.startswith("openai/"):
858
  return OpenAIEmbeddingModel(model_id, api_key=api_key)
859
  else:
@@ -897,7 +1049,11 @@ def validate_model_id(model_id: str) -> tuple[bool, str]:
897
  # Check for Gemini models
898
  if model_id.startswith("gemini/"):
899
  return True, ""
900
-
 
 
 
 
901
  # For custom models, check if it looks like a valid HF model ID
902
  if "/" not in model_id:
903
  return False, "Model ID should be in format 'organization/model-name'"
@@ -944,9 +1100,11 @@ def get_api_key_type(model_id: str) -> Optional[str]:
944
  return "voyage"
945
  elif model_type == "gemini" or model_id.startswith("gemini/"):
946
  return "gemini"
 
 
947
  elif model_type == "openai" or model_id.startswith("openai/"):
948
  return "openai"
949
-
950
  return None
951
 
952
 
@@ -967,6 +1125,8 @@ def get_api_key_env_var(model_id: str) -> Optional[str]:
967
  return "VOYAGE_API_KEY"
968
  elif key_type == "gemini":
969
  return "GEMINI_API_KEY"
 
 
970
  return None
971
 
972
 
 
137
  "model_name": "gemini-embedding-001",
138
  "dimensions": 1536,
139
  },
140
+ "cohere/embed-multilingual-v3.0": {
141
+ "name": "Cohere embed-multilingual-v3.0",
142
+ "description": "Cohere's multilingual embedding model, 100+ languages (API key required)",
143
+ "type": "cohere",
144
+ "model_name": "embed-multilingual-v3.0",
145
+ "dimensions": 1024,
146
+ },
147
+ "cohere/embed-multilingual-light-v3.0": {
148
+ "name": "Cohere embed-multilingual-light-v3.0",
149
+ "description": "Cohere's lightweight multilingual model (API key required)",
150
+ "type": "cohere",
151
+ "model_name": "embed-multilingual-light-v3.0",
152
+ "dimensions": 384,
153
+ },
154
  }
155
 
156
  # Merge all models for easy lookup
 
785
  return self.config.get("description", "")
786
 
787
 
788
class CohereEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Cohere embedding API with consistent interface.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Cohere embedding model.

        Args:
            model_id: Model ID in format 'cohere/model-name'
            api_key: Cohere API key (or uses COHERE_API_KEY env var)

        Raises:
            ImportError: If the cohere package is not installed.
            ValueError: If no API key is provided and COHERE_API_KEY is unset.
        """
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "Cohere package not installed. Install with: pip install cohere"
            )

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key required. Set COHERE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.client = cohere.Client(api_key=api_key)

        # Get model config; unknown models fall back to sensible defaults
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Cohere embedding model",
            "type": "cohere",
            "model_name": model_id.replace("cohere/", ""),
            "dimensions": 1024,  # Default dimension for Cohere v3 models
        })

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Cohere embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 96,  # Cohere supports up to 96 texts per request
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Cohere API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses search_query vs search_document)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)

        Raises:
            RuntimeError: If the API call still fails after all retries.
        """
        import time

        # Empty input would produce a 1-D np.array([]) and crash the
        # axis=1 normalization below — return an empty 2-D array instead.
        if not texts:
            return np.zeros((0, self.embedding_dim), dtype=np.float32)

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Cohere v3 models require input_type for asymmetric embeddings
        input_type = "search_query" if is_query else "search_document"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            # Retry with exponential backoff (1s, 2s) on transient API errors
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        texts=batch,
                        model=self._model_name,
                        input_type=input_type,
                    )

                    # Extract embeddings from response
                    batch_embeddings = result.embeddings
                    all_embeddings.extend(batch_embeddings)
                    break

                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f"  API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        # Chain the original exception so the root cause
                        # survives in the traceback.
                        raise RuntimeError(
                            f"Cohere API error after {max_retries} retries: {e}"
                        ) from e

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Normalize if requested (np.maximum guards against zero vectors)
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
922
  def get_curated_model_choices() -> list[tuple[str, str]]:
923
  """
924
  Get list of curated local models for UI dropdown.
 
970
  return True
971
  if model_id.startswith("gemini/"):
972
  return True
973
+ if model_id.startswith("cohere/"):
974
+ return True
975
+
976
  return False
977
 
978
 
 
1004
  return VoyageEmbeddingModel(model_id, api_key=api_key)
1005
  elif model_type == "gemini" or model_id.startswith("gemini/"):
1006
  return GeminiEmbeddingModel(model_id, api_key=api_key)
1007
+ elif model_type == "cohere" or model_id.startswith("cohere/"):
1008
+ return CohereEmbeddingModel(model_id, api_key=api_key)
1009
  elif model_type == "openai" or model_id.startswith("openai/"):
1010
  return OpenAIEmbeddingModel(model_id, api_key=api_key)
1011
  else:
 
1049
  # Check for Gemini models
1050
  if model_id.startswith("gemini/"):
1051
  return True, ""
1052
+
1053
+ # Check for Cohere models
1054
+ if model_id.startswith("cohere/"):
1055
+ return True, ""
1056
+
1057
  # For custom models, check if it looks like a valid HF model ID
1058
  if "/" not in model_id:
1059
  return False, "Model ID should be in format 'organization/model-name'"
 
1100
  return "voyage"
1101
  elif model_type == "gemini" or model_id.startswith("gemini/"):
1102
  return "gemini"
1103
+ elif model_type == "cohere" or model_id.startswith("cohere/"):
1104
+ return "cohere"
1105
  elif model_type == "openai" or model_id.startswith("openai/"):
1106
  return "openai"
1107
+
1108
  return None
1109
 
1110
 
 
1125
  return "VOYAGE_API_KEY"
1126
  elif key_type == "gemini":
1127
  return "GEMINI_API_KEY"
1128
+ elif key_type == "cohere":
1129
+ return "COHERE_API_KEY"
1130
  return None
1131
 
1132
 
requirements.txt CHANGED
@@ -19,4 +19,5 @@ openai>=1.0.0
19
  tiktoken>=0.5.0
20
  voyageai>=0.3.0
21
  google-genai>=1.0.0
 
22
 
 
19
  tiktoken>=0.5.0
20
  voyageai>=0.3.0
21
  google-genai>=1.0.0
22
+ cohere>=5.0.0
23