Spaces:
Configuration error
Configuration error
oremaz
commited on
Commit
·
1a3b775
1
Parent(s):
6aad47c
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -200,64 +200,51 @@ def initialize_models(use_api_mode=False):
|
|
| 200 |
|
| 201 |
from typing import Any, List, Optional
|
| 202 |
from llama_index.core.embeddings import BaseEmbedding
|
| 203 |
-
import
|
| 204 |
-
from
|
| 205 |
|
| 206 |
-
class
|
| 207 |
"""
|
| 208 |
-
|
| 209 |
"""
|
| 210 |
|
| 211 |
-
def __init__(self,
|
| 212 |
-
model_name_bge: str = "BAAI/bge-base-en-v1.5",
|
| 213 |
-
model_weight_path: str = "path/to/Visualized_base_en_v1.5.pth",
|
| 214 |
-
**kwargs: Any) -> None:
|
| 215 |
super().__init__(**kwargs)
|
| 216 |
-
|
| 217 |
-
self._model = Visualized_BGE(
|
| 218 |
-
model_name_bge=model_name_bge,
|
| 219 |
-
model_weight=model_weight_path
|
| 220 |
-
)
|
| 221 |
-
self._model.eval()
|
| 222 |
|
| 223 |
@classmethod
|
| 224 |
def class_name(cls) -> str:
|
| 225 |
-
return "
|
| 226 |
|
| 227 |
def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
embedding = self._model.encode(text=query)
|
| 236 |
-
return embedding.cpu().numpy().tolist()
|
| 237 |
|
| 238 |
def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
embedding = self._model.encode(text=text)
|
| 247 |
-
return embedding.cpu().numpy().tolist()
|
| 248 |
|
| 249 |
def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
|
| 250 |
-
"""Batch embedding generation."""
|
| 251 |
embeddings = []
|
| 252 |
image_paths = image_paths or [None] * len(texts)
|
| 253 |
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
|
| 262 |
return embeddings
|
| 263 |
|
|
@@ -268,7 +255,7 @@ def initialize_models(use_api_mode=False):
|
|
| 268 |
return self._get_text_embedding(text, image_path)
|
| 269 |
|
| 270 |
|
| 271 |
-
embed_model =
|
| 272 |
# Code LLM
|
| 273 |
code_llm = HuggingFaceLLM(
|
| 274 |
model_name="Qwen/Qwen2.5-Coder-3B-Instruct",
|
|
|
|
| 200 |
|
| 201 |
from typing import Any, List, Optional
|
| 202 |
from llama_index.core.embeddings import BaseEmbedding
|
| 203 |
+
from sentence_transformers import SentenceTransformer
|
| 204 |
+
from PIL import Image
|
| 205 |
|
| 206 |
+
class MultimodalCLIPEmbedding(BaseEmbedding):
|
| 207 |
"""
|
| 208 |
+
Custom embedding class using CLIP for multimodal capabilities.
|
| 209 |
"""
|
| 210 |
|
| 211 |
+
def __init__(self, model_name: str = "clip-ViT-B-32", **kwargs: Any) -> None:
|
|
|
|
|
|
|
|
|
|
| 212 |
super().__init__(**kwargs)
|
| 213 |
+
self._model = SentenceTransformer(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
@classmethod
|
| 216 |
def class_name(cls) -> str:
|
| 217 |
+
return "multimodal_clip"
|
| 218 |
|
| 219 |
def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
| 220 |
+
if image_path:
|
| 221 |
+
image = Image.open(image_path)
|
| 222 |
+
embedding = self._model.encode(image)
|
| 223 |
+
return embedding.tolist()
|
| 224 |
+
else:
|
| 225 |
+
embedding = self._model.encode(query)
|
| 226 |
+
return embedding.tolist()
|
|
|
|
|
|
|
| 227 |
|
| 228 |
def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
| 229 |
+
if image_path:
|
| 230 |
+
image = Image.open(image_path)
|
| 231 |
+
embedding = self._model.encode(image)
|
| 232 |
+
return embedding.tolist()
|
| 233 |
+
else:
|
| 234 |
+
embedding = self._model.encode(text)
|
| 235 |
+
return embedding.tolist()
|
|
|
|
|
|
|
| 236 |
|
| 237 |
def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
|
|
|
|
| 238 |
embeddings = []
|
| 239 |
image_paths = image_paths or [None] * len(texts)
|
| 240 |
|
| 241 |
+
for text, img_path in zip(texts, image_paths):
|
| 242 |
+
if img_path:
|
| 243 |
+
image = Image.open(img_path)
|
| 244 |
+
emb = self._model.encode(image)
|
| 245 |
+
else:
|
| 246 |
+
emb = self._model.encode(text)
|
| 247 |
+
embeddings.append(emb.tolist())
|
| 248 |
|
| 249 |
return embeddings
|
| 250 |
|
|
|
|
| 255 |
return self._get_text_embedding(text, image_path)
|
| 256 |
|
| 257 |
|
| 258 |
+
embed_model = MultimodalCLIPEmbedding()
|
| 259 |
# Code LLM
|
| 260 |
code_llm = HuggingFaceLLM(
|
| 261 |
model_name="Qwen/Qwen2.5-Coder-3B-Instruct",
|