Spaces:
Sleeping
Sleeping
Lev Israel committed on
Commit ·
1787d7f
1
Parent(s): e962d38
Support Cohere
Browse files- .claude/settings.local.json +11 -1
- app.py +3 -0
- models.py +163 -3
- requirements.txt +1 -0
.claude/settings.local.json
CHANGED
|
@@ -2,7 +2,17 @@
|
|
| 2 |
"permissions": {
|
| 3 |
"allow": [
|
| 4 |
"Bash(python -m py_compile:*)",
|
| 5 |
-
"Bash(python:*)"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
]
|
| 7 |
}
|
| 8 |
}
|
|
|
|
| 2 |
"permissions": {
|
| 3 |
"allow": [
|
| 4 |
"Bash(python -m py_compile:*)",
|
| 5 |
+
"Bash(python:*)",
|
| 6 |
+
"Bash(grep:*)",
|
| 7 |
+
"Bash(wc:*)",
|
| 8 |
+
"WebSearch",
|
| 9 |
+
"WebFetch(domain:docs.cohere.com)",
|
| 10 |
+
"WebFetch(domain:github.com)",
|
| 11 |
+
"WebFetch(domain:pypi.org)",
|
| 12 |
+
"WebFetch(domain:qdrant.tech)",
|
| 13 |
+
"WebFetch(domain:zilliz.com)",
|
| 14 |
+
"WebFetch(domain:docs.pinecone.io)",
|
| 15 |
+
"WebFetch(domain:cohere.com)"
|
| 16 |
]
|
| 17 |
}
|
| 18 |
}
|
app.py
CHANGED
|
@@ -379,6 +379,9 @@ def update_model_inputs_visibility(choice):
|
|
| 379 |
elif key_type == "gemini":
|
| 380 |
label = "Gemini API Key (optional if using gcloud)"
|
| 381 |
placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
|
|
|
|
|
|
|
|
|
|
| 382 |
else:
|
| 383 |
label = "OpenAI API Key"
|
| 384 |
placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
|
|
|
|
| 379 |
elif key_type == "gemini":
|
| 380 |
label = "Gemini API Key (optional if using gcloud)"
|
| 381 |
placeholder = f"Leave blank if using gcloud ADC, or enter API key / set {env_var}"
|
| 382 |
+
elif key_type == "cohere":
|
| 383 |
+
label = "Cohere API Key"
|
| 384 |
+
placeholder = f"Enter your Cohere API key (or set {env_var} env var)"
|
| 385 |
else:
|
| 386 |
label = "OpenAI API Key"
|
| 387 |
placeholder = f"Enter your OpenAI API key (or set {env_var} env var)"
|
models.py
CHANGED
|
@@ -137,6 +137,20 @@ API_MODELS = {
|
|
| 137 |
"model_name": "gemini-embedding-001",
|
| 138 |
"dimensions": 1536,
|
| 139 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
}
|
| 141 |
|
| 142 |
# Merge all models for easy lookup
|
|
@@ -771,6 +785,140 @@ class GeminiEmbeddingModel(BaseEmbeddingModel):
|
|
| 771 |
return self.config.get("description", "")
|
| 772 |
|
| 773 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
def get_curated_model_choices() -> list[tuple[str, str]]:
|
| 775 |
"""
|
| 776 |
Get list of curated local models for UI dropdown.
|
|
@@ -822,7 +970,9 @@ def is_api_model(model_id: str) -> bool:
|
|
| 822 |
return True
|
| 823 |
if model_id.startswith("gemini/"):
|
| 824 |
return True
|
| 825 |
-
|
|
|
|
|
|
|
| 826 |
return False
|
| 827 |
|
| 828 |
|
|
@@ -854,6 +1004,8 @@ def load_model(
|
|
| 854 |
return VoyageEmbeddingModel(model_id, api_key=api_key)
|
| 855 |
elif model_type == "gemini" or model_id.startswith("gemini/"):
|
| 856 |
return GeminiEmbeddingModel(model_id, api_key=api_key)
|
|
|
|
|
|
|
| 857 |
elif model_type == "openai" or model_id.startswith("openai/"):
|
| 858 |
return OpenAIEmbeddingModel(model_id, api_key=api_key)
|
| 859 |
else:
|
|
@@ -897,7 +1049,11 @@ def validate_model_id(model_id: str) -> tuple[bool, str]:
|
|
| 897 |
# Check for Gemini models
|
| 898 |
if model_id.startswith("gemini/"):
|
| 899 |
return True, ""
|
| 900 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 901 |
# For custom models, check if it looks like a valid HF model ID
|
| 902 |
if "/" not in model_id:
|
| 903 |
return False, "Model ID should be in format 'organization/model-name'"
|
|
@@ -944,9 +1100,11 @@ def get_api_key_type(model_id: str) -> Optional[str]:
|
|
| 944 |
return "voyage"
|
| 945 |
elif model_type == "gemini" or model_id.startswith("gemini/"):
|
| 946 |
return "gemini"
|
|
|
|
|
|
|
| 947 |
elif model_type == "openai" or model_id.startswith("openai/"):
|
| 948 |
return "openai"
|
| 949 |
-
|
| 950 |
return None
|
| 951 |
|
| 952 |
|
|
@@ -967,6 +1125,8 @@ def get_api_key_env_var(model_id: str) -> Optional[str]:
|
|
| 967 |
return "VOYAGE_API_KEY"
|
| 968 |
elif key_type == "gemini":
|
| 969 |
return "GEMINI_API_KEY"
|
|
|
|
|
|
|
| 970 |
return None
|
| 971 |
|
| 972 |
|
|
|
|
| 137 |
"model_name": "gemini-embedding-001",
|
| 138 |
"dimensions": 1536,
|
| 139 |
},
|
| 140 |
+
"cohere/embed-multilingual-v3.0": {
|
| 141 |
+
"name": "Cohere embed-multilingual-v3.0",
|
| 142 |
+
"description": "Cohere's multilingual embedding model, 100+ languages (API key required)",
|
| 143 |
+
"type": "cohere",
|
| 144 |
+
"model_name": "embed-multilingual-v3.0",
|
| 145 |
+
"dimensions": 1024,
|
| 146 |
+
},
|
| 147 |
+
"cohere/embed-multilingual-light-v3.0": {
|
| 148 |
+
"name": "Cohere embed-multilingual-light-v3.0",
|
| 149 |
+
"description": "Cohere's lightweight multilingual model (API key required)",
|
| 150 |
+
"type": "cohere",
|
| 151 |
+
"model_name": "embed-multilingual-light-v3.0",
|
| 152 |
+
"dimensions": 384,
|
| 153 |
+
},
|
| 154 |
}
|
| 155 |
|
| 156 |
# Merge all models for easy lookup
|
|
|
|
| 785 |
return self.config.get("description", "")
|
| 786 |
|
| 787 |
|
| 788 |
+
class CohereEmbeddingModel(BaseEmbeddingModel):
    """
    Wrapper for Cohere embedding API with consistent interface.

    Mirrors the other API-backed models in this module (OpenAI, Voyage,
    Gemini): lazy import of the vendor SDK, API key from argument or
    environment, config lookup in API_MODELS, and an ``encode`` method
    returning a float32 numpy array.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the Cohere embedding model.

        Args:
            model_id: Model ID in format 'cohere/model-name'
            api_key: Cohere API key (or uses COHERE_API_KEY env var)

        Raises:
            ImportError: if the ``cohere`` package is not installed.
            ValueError: if no API key is available.
        """
        try:
            import cohere
        except ImportError as e:
            # Chain the original ImportError so the real failure is visible.
            raise ImportError(
                "Cohere package not installed. Install with: pip install cohere"
            ) from e

        self.model_id = model_id

        # Get API key from parameter or environment
        api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not api_key:
            raise ValueError(
                "Cohere API key required. Set COHERE_API_KEY environment variable "
                "or pass api_key parameter."
            )

        self.client = cohere.Client(api_key=api_key)

        # Get model config; fall back to a generic entry for unknown
        # 'cohere/...' IDs so custom model names still work.
        self.config = API_MODELS.get(model_id, {
            "name": model_id,
            "description": "Cohere embedding model",
            "type": "cohere",
            "model_name": model_id.replace("cohere/", ""),
            "dimensions": 1024,  # Default dimension
        })

        self._model_name = self.config["model_name"]
        self.embedding_dim = self.config["dimensions"]

        print(f"Initialized Cohere embedding model: {self._model_name}")
        print(f"Embedding dimension: {self.embedding_dim}")

    def encode(
        self,
        texts: list[str],
        is_query: bool = False,
        batch_size: int = 96,  # Cohere supports up to 96 texts per request
        show_progress: bool = True,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode texts to embeddings using Cohere API.

        Args:
            texts: List of texts to encode
            is_query: Whether these are queries (uses search_query vs search_document)
            batch_size: Batch size for API calls
            show_progress: Whether to show progress bar
            normalize: Whether to L2-normalize embeddings

        Returns:
            numpy array of shape (len(texts), embedding_dim)

        Raises:
            RuntimeError: if the API still fails after all retries.
        """
        import time

        all_embeddings = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        # Cohere v3 models require input_type for asymmetric embeddings
        input_type = "search_query" if is_query else "search_document"

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            batch_num = i // batch_size + 1

            if show_progress:
                print(f"  Encoding batch {batch_num}/{total_batches}...")

            # Retry with exponential backoff (1s, 2s) before giving up.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    result = self.client.embed(
                        texts=batch,
                        model=self._model_name,
                        input_type=input_type,
                    )

                    # Extract embeddings from response
                    batch_embeddings = result.embeddings
                    all_embeddings.extend(batch_embeddings)
                    break

                except Exception as e:
                    if attempt < max_retries - 1:
                        wait_time = 2 ** attempt
                        print(f"  API error, retrying in {wait_time}s: {e}")
                        time.sleep(wait_time)
                    else:
                        # Chain the last API error as the cause.
                        raise RuntimeError(
                            f"Cohere API error after {max_retries} retries: {e}"
                        ) from e

            # Small delay to avoid rate limits
            if i + batch_size < len(texts):
                time.sleep(0.1)

        embeddings = np.array(all_embeddings, dtype=np.float32)

        # Normalize if requested; clamp norms to avoid division by zero.
        if normalize:
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
            embeddings = embeddings / np.maximum(norms, 1e-10)

        return embeddings

    @property
    def name(self) -> str:
        """Get display name for the model."""
        return self.config.get("name", self.model_id)

    @property
    def description(self) -> str:
        """Get description for the model."""
        return self.config.get("description", "")
+
|
| 921 |
+
|
| 922 |
def get_curated_model_choices() -> list[tuple[str, str]]:
|
| 923 |
"""
|
| 924 |
Get list of curated local models for UI dropdown.
|
|
|
|
| 970 |
return True
|
| 971 |
if model_id.startswith("gemini/"):
|
| 972 |
return True
|
| 973 |
+
if model_id.startswith("cohere/"):
|
| 974 |
+
return True
|
| 975 |
+
|
| 976 |
return False
|
| 977 |
|
| 978 |
|
|
|
|
| 1004 |
return VoyageEmbeddingModel(model_id, api_key=api_key)
|
| 1005 |
elif model_type == "gemini" or model_id.startswith("gemini/"):
|
| 1006 |
return GeminiEmbeddingModel(model_id, api_key=api_key)
|
| 1007 |
+
elif model_type == "cohere" or model_id.startswith("cohere/"):
|
| 1008 |
+
return CohereEmbeddingModel(model_id, api_key=api_key)
|
| 1009 |
elif model_type == "openai" or model_id.startswith("openai/"):
|
| 1010 |
return OpenAIEmbeddingModel(model_id, api_key=api_key)
|
| 1011 |
else:
|
|
|
|
| 1049 |
# Check for Gemini models
|
| 1050 |
if model_id.startswith("gemini/"):
|
| 1051 |
return True, ""
|
| 1052 |
+
|
| 1053 |
+
# Check for Cohere models
|
| 1054 |
+
if model_id.startswith("cohere/"):
|
| 1055 |
+
return True, ""
|
| 1056 |
+
|
| 1057 |
# For custom models, check if it looks like a valid HF model ID
|
| 1058 |
if "/" not in model_id:
|
| 1059 |
return False, "Model ID should be in format 'organization/model-name'"
|
|
|
|
| 1100 |
return "voyage"
|
| 1101 |
elif model_type == "gemini" or model_id.startswith("gemini/"):
|
| 1102 |
return "gemini"
|
| 1103 |
+
elif model_type == "cohere" or model_id.startswith("cohere/"):
|
| 1104 |
+
return "cohere"
|
| 1105 |
elif model_type == "openai" or model_id.startswith("openai/"):
|
| 1106 |
return "openai"
|
| 1107 |
+
|
| 1108 |
return None
|
| 1109 |
|
| 1110 |
|
|
|
|
| 1125 |
return "VOYAGE_API_KEY"
|
| 1126 |
elif key_type == "gemini":
|
| 1127 |
return "GEMINI_API_KEY"
|
| 1128 |
+
elif key_type == "cohere":
|
| 1129 |
+
return "COHERE_API_KEY"
|
| 1130 |
return None
|
| 1131 |
|
| 1132 |
|
requirements.txt
CHANGED
|
@@ -19,4 +19,5 @@ openai>=1.0.0
|
|
| 19 |
tiktoken>=0.5.0
|
| 20 |
voyageai>=0.3.0
|
| 21 |
google-genai>=1.0.0
|
|
|
|
| 22 |
|
|
|
|
| 19 |
tiktoken>=0.5.0
|
| 20 |
voyageai>=0.3.0
|
| 21 |
google-genai>=1.0.0
|
| 22 |
+
cohere>=5.0.0
|
| 23 |
|