Commit c02a77b: Update agent.py
Author: oremaz · Parent: 798378e
agent.py CHANGED
@@ -198,16 +198,82 @@ def initialize_models(use_api_mode=False):
 
     proj_llm = QwenVL7BCustomLLM()
 
-
-
-
-
-
+    from typing import Any, List, Optional
+    from llama_index.core.embeddings import BaseEmbedding
+    import torch
+    from FlagEmbedding.visual.modeling import Visualized_BGE
+
+    class BAAIVisualizedAdvanced(BaseEmbedding):
+        """
+        Advanced implementation using FlagEmbedding's Visualized_BGE.
+        """
+
+        def __init__(self,
+                     model_name_bge: str = "BAAI/bge-base-en-v1.5",
+                     model_weight_path: str = "path/to/Visualized_base_en_v1.5.pth",
+                     **kwargs: Any) -> None:
+            super().__init__(**kwargs)
+            # Initialize the Visualized BGE model
+            self._model = Visualized_BGE(
+                model_name_bge=model_name_bge,
+                model_weight=model_weight_path
+            )
+            self._model.eval()
+
+        @classmethod
+        def class_name(cls) -> str:
+            return "baai_visualized_advanced"
+
+        def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
+            """Generate embedding for query with optional image."""
+            with torch.no_grad():
+                if image_path:
+                    # Encode both text and image
+                    embedding = self._model.encode(image=image_path, text=query)
+                else:
+                    # Text-only encoding
+                    embedding = self._model.encode(text=query)
+            return embedding.cpu().numpy().tolist()
+
+        def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
+            """Generate embedding for text with optional image."""
+            with torch.no_grad():
+                if image_path:
+                    # Image-only encoding
+                    embedding = self._model.encode(image=image_path)
+                else:
+                    # Text-only encoding
+                    embedding = self._model.encode(text=text)
+            return embedding.cpu().numpy().tolist()
+
+        def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
+            """Batch embedding generation."""
+            embeddings = []
+            image_paths = image_paths or [None] * len(texts)
+
+            with torch.no_grad():
+                for text, img_path in zip(texts, image_paths):
+                    if img_path:
+                        emb = self._model.encode(image=img_path, text=text)
+                    else:
+                        emb = self._model.encode(text=text)
+                    embeddings.append(emb.cpu().numpy().tolist())
+
+            return embeddings
+
+        async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
+            return self._get_query_embedding(query, image_path)
+
+        async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
+            return self._get_text_embedding(text, image_path)
+
+
+    embed_model = BAAIVisualizedAdvanced()
     # Code LLM
     code_llm = HuggingFaceLLM(
-        model_name="Qwen/Qwen2.5-Coder-
-        tokenizer_name="Qwen/Qwen2.5-Coder-
-        device_map="
+        model_name="Qwen/Qwen2.5-Coder-3B-Instruct",
+        tokenizer_name="Qwen/Qwen2.5-Coder-3B-Instruct",
+        device_map="auto",
         model_kwargs={"torch_dtype": "auto"},
         generate_kwargs={"do_sample": False}
     )

@@ -896,8 +962,31 @@ async def main():
     }
 
     print(question_data)
-
-
+    proj_llm, code_llm, embed_model = initialize_models(use_api_mode=False)
+
+    # Test with image
+    file_path = "test_image.jpg"
+
+    # Test proj_llm (multimodal LLM)
+    response = proj_llm.complete(
+        prompt="Describe what you see in this image.",
+        image_paths=[file_path]
+    )
+    print(f"LLM Response: {response.text}")
+
+    # Test embed_model with image
+    image_embedding = embed_model._get_text_embedding("", image_path=file_path)
+    print(f"Image embedding dimension: {len(image_embedding)}")
+    print(f"First 5 elements: {image_embedding[:5]}")
+
+    # Test embed_model with text
+    text_embedding = embed_model._get_text_embedding("A red sports car")
+    print(f"Text embedding dimension: {len(text_embedding)}")
+
+    # Test multimodal embedding (text + image)
+    multimodal_embedding = embed_model._get_query_embedding("red car", image_path=file_path)
+    print(f"Multimodal embedding dimension: {len(multimodal_embedding)}")
+
     #answer = await agent.solve_gaia_question(question_data)
     #print(f"Answer: {answer}")
 
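Since the new BAAIVisualizedAdvanced class wraps FlagEmbedding's Visualized BGE, the natural smoke test beyond the one in main() is a composed text-plus-image similarity check. Below is a minimal sketch, not part of the commit: the checkpoint path and image filenames are placeholders, and the defensive ravel() guards against encode() returning a batch of one.

# Sketch only: similarity check with the commit's embedding class.
# Assumes BAAIVisualizedAdvanced is in scope and that the
# Visualized_base_en_v1.5.pth checkpoint from the FlagEmbedding
# repo has been downloaded (the path below is a placeholder).
import numpy as np

embed_model = BAAIVisualizedAdvanced(
    model_name_bge="BAAI/bge-base-en-v1.5",
    model_weight_path="Visualized_base_en_v1.5.pth",
)

# Multimodal query (text + image) against an image-only candidate
# ("candidate.jpg" is a hypothetical local file).
query = np.asarray(
    embed_model._get_query_embedding("red car", image_path="test_image.jpg")
).ravel()
candidate = np.asarray(
    embed_model._get_text_embedding("", image_path="candidate.jpg")
).ravel()

# FlagEmbedding's examples compare Visualized BGE outputs with a plain
# dot product (embeddings come back normalized), so this is cosine similarity.
print(f"cosine similarity: {float(query @ candidate):.4f}")

Because the class implements LlamaIndex's standard BaseEmbedding interface, it could also be assigned to llama_index.core.Settings.embed_model for ordinary text indexing; note, though, that LlamaIndex's built-in retrieval path calls _get_query_embedding(query) without an image_path, so the multimodal variants must be invoked explicitly as above.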