Spaces:
Sleeping
Sleeping
Commit
·
549c270
0
Parent(s):
Deploy: Minimal FastAPI backend for CoVE Space
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +1 -0
- data/processed/beauty/index/.ipynb_checkpoints/defaults-checkpoint.json +9 -0
- data/processed/beauty/index/.ipynb_checkpoints/defaults_cove-checkpoint.json +0 -0
- data/processed/beauty/index/defaults.json +16 -0
- data/processed/beauty/index/defaults_cove.json +22 -0
- requirements.txt +13 -0
- space.yaml +5 -0
- src/__init__.py +2 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/agents/.ipynb_checkpoints/agent_types-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/base-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/chat_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/data_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/index_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/model_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/orchestrator-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/recommend_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/report_agent-checkpoint.py +0 -0
- src/agents/.ipynb_checkpoints/run_agent-checkpoint.py +0 -0
- src/agents/__init__.py +1 -0
- src/agents/__pycache__/__init__.cpython-311.pyc +0 -0
- src/agents/__pycache__/chat_agent.cpython-311.pyc +0 -0
- src/agents/__pycache__/orchestrator.cpython-311.pyc +0 -0
- src/agents/__pycache__/report_agent.cpython-311.pyc +0 -0
- src/agents/agent_types.py +16 -0
- src/agents/base.py +16 -0
- src/agents/chat_agent.py +311 -0
- src/agents/data_agent.py +46 -0
- src/agents/index_agent.py +34 -0
- src/agents/model_agent.py +8 -0
- src/agents/orchestrator.py +44 -0
- src/agents/recommend_agent.py +37 -0
- src/agents/report_agent.py +319 -0
- src/agents/run_agent.py +28 -0
- src/cove/.ipynb_checkpoints/__init__-checkpoint.py +0 -0
- src/cove/.ipynb_checkpoints/fuse_index-checkpoint.py +0 -0
- src/cove/.ipynb_checkpoints/io-checkpoint.py +0 -0
- src/cove/__init__.py +0 -0
- src/cove/fuse_index.py +106 -0
- src/cove/io.py +29 -0
- src/data/.ipynb_checkpoints/init-checkpoint.py +0 -0
- src/data/.ipynb_checkpoints/loader-checkpoint.py +0 -0
- src/data/.ipynb_checkpoints/registry-checkpoint.py +0 -0
- src/data/__init__.py +2 -0
- src/data/__pycache__/__init__.cpython-311.pyc +0 -0
- src/data/__pycache__/loader.cpython-311.pyc +0 -0
- src/data/__pycache__/registry.cpython-311.pyc +0 -0
- src/data/loader.py +15 -0
- src/data/registry.py +73 -0
- src/models/.ipynb_checkpoints/fusion-checkpoint.py +0 -0
app.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from api.app_api import app
|
data/processed/beauty/index/.ipynb_checkpoints/defaults-checkpoint.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"weighted": {
|
| 3 |
+
"w_text": 1.0,
|
| 4 |
+
"w_image": 0.0,
|
| 5 |
+
"w_meta": 0.2,
|
| 6 |
+
"k": 10,
|
| 7 |
+
"faiss_name": "weighted_wt1.0_wi0.0_wm0.2"
|
| 8 |
+
}
|
| 9 |
+
}
|
data/processed/beauty/index/.ipynb_checkpoints/defaults_cove-checkpoint.json
ADDED
|
File without changes
|
data/processed/beauty/index/defaults.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"concat": {
|
| 3 |
+
"w_text": 1.0,
|
| 4 |
+
"w_image": 0.2,
|
| 5 |
+
"w_meta": 0.2,
|
| 6 |
+
"k": 10,
|
| 7 |
+
"faiss_name": "beauty_concat"
|
| 8 |
+
},
|
| 9 |
+
"weighted": {
|
| 10 |
+
"w_text": 1.0,
|
| 11 |
+
"w_image": 0.2,
|
| 12 |
+
"w_meta": 0.2,
|
| 13 |
+
"k": 10,
|
| 14 |
+
"faiss_name": "beauty_weighted"
|
| 15 |
+
}
|
| 16 |
+
}
|
data/processed/beauty/index/defaults_cove.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cove_faiss_only": {
|
| 3 |
+
"k": 10,
|
| 4 |
+
"faiss_name": "beauty_cove_faiss_only"
|
| 5 |
+
},
|
| 6 |
+
"cove_faiss_concat": {
|
| 7 |
+
"w_text": 0.2,
|
| 8 |
+
"w_image": 0.2,
|
| 9 |
+
"w_meta": 0.2,
|
| 10 |
+
"w_cove": 0.4,
|
| 11 |
+
"k": 10,
|
| 12 |
+
"faiss_name": "beauty_cove_faiss_concat"
|
| 13 |
+
},
|
| 14 |
+
"cove_faiss_weighted": {
|
| 15 |
+
"w_text": 0.2,
|
| 16 |
+
"w_image": 0.2,
|
| 17 |
+
"w_meta": 0.2,
|
| 18 |
+
"w_cove": 0.4,
|
| 19 |
+
"k": 10,
|
| 20 |
+
"faiss_name": "beauty_cove_faiss_weighted"
|
| 21 |
+
}
|
| 22 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.110.0
|
| 2 |
+
uvicorn==0.27.0.post1
|
| 3 |
+
pydantic==1.10.14
|
| 4 |
+
numpy==1.24.4
|
| 5 |
+
pandas==2.2.1
|
| 6 |
+
faiss-cpu==1.7.4
|
| 7 |
+
scikit-learn==1.4.0
|
| 8 |
+
tqdm==4.66.2
|
| 9 |
+
sentence-transformers==2.6.1
|
| 10 |
+
transformers==4.39.3
|
| 11 |
+
torch==2.1.2
|
| 12 |
+
protobuf==4.25.3
|
| 13 |
+
pyarrow==15.0.2
|
space.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# space.yaml
|
| 2 |
+
title: "CoVE API"
|
| 3 |
+
sdk: "docker"
|
| 4 |
+
app_file: "api/app_api.py"
|
| 5 |
+
python_version: "3.11"
|
src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal package init to avoid import-time side effects
|
| 2 |
+
__all__ = []
|
src/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (173 Bytes). View file
|
|
|
src/agents/.ipynb_checkpoints/agent_types-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/base-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/chat_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/data_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/index_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/model_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/orchestrator-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/recommend_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/report_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/.ipynb_checkpoints/run_agent-checkpoint.py
ADDED
|
File without changes
|
src/agents/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# auto-created to mark package
|
src/agents/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (171 Bytes). View file
|
|
|
src/agents/__pycache__/chat_agent.cpython-311.pyc
ADDED
|
Binary file (17.1 kB). View file
|
|
|
src/agents/__pycache__/orchestrator.cpython-311.pyc
ADDED
|
Binary file (17.3 kB). View file
|
|
|
src/agents/__pycache__/report_agent.cpython-311.pyc
ADDED
|
Binary file (18.5 kB). View file
|
|
|
src/agents/agent_types.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
@dataclass
|
| 5 |
+
class Task:
|
| 6 |
+
intent: str # "prepare" | "index" | "eval" | "recommend" | "report"
|
| 7 |
+
dataset: str = "beauty"
|
| 8 |
+
user: Optional[str] = None
|
| 9 |
+
k: int = 10
|
| 10 |
+
fusion: str = "concat"
|
| 11 |
+
w_text: float = 1.0
|
| 12 |
+
w_image: float = 1.0
|
| 13 |
+
w_meta: float = 0.0
|
| 14 |
+
use_faiss: bool = True
|
| 15 |
+
faiss_name: Optional[str] = None
|
| 16 |
+
exclude_seen: bool = True
|
src/agents/base.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
from .types import Task, StepResult
|
| 4 |
+
|
| 5 |
+
class BaseAgent(ABC):
|
| 6 |
+
name: str = "base"
|
| 7 |
+
|
| 8 |
+
@abstractmethod
|
| 9 |
+
def run(self, task: Task) -> StepResult:
|
| 10 |
+
...
|
| 11 |
+
|
| 12 |
+
def ok(self, detail: str = "", **artifacts) -> StepResult:
|
| 13 |
+
return StepResult(name=self.name, status="succeeded", detail=detail, artifacts=artifacts)
|
| 14 |
+
|
| 15 |
+
def fail(self, detail: str = "", **artifacts) -> StepResult:
|
| 16 |
+
return StepResult(name=self.name, status="failed", detail=detail, artifacts=artifacts)
|
src/agents/chat_agent.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/agents/chat_agent.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import ast
|
| 5 |
+
import math
|
| 6 |
+
import re
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
from src.utils.paths import get_processed_path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ----------------------------- simple config -----------------------------
|
| 16 |
+
@dataclass
|
| 17 |
+
class ChatAgentConfig:
|
| 18 |
+
# words to ignore when pulling a keyword from the prompt
|
| 19 |
+
stopwords: frozenset = frozenset(
|
| 20 |
+
{
|
| 21 |
+
"under", "below", "less", "than", "beneath",
|
| 22 |
+
"recommend", "something", "for", "me", "i", "need", "want",
|
| 23 |
+
"a", "an", "the", "please", "pls", "ok", "okay",
|
| 24 |
+
"price", "priced", "cost", "costing", "buy", "find", "search",
|
| 25 |
+
"show", "give", "with", "and", "or", "of", "to", "in", "on",
|
| 26 |
+
}
|
| 27 |
+
)
|
| 28 |
+
# price pattern: $12, 12, 12.5
|
| 29 |
+
price_re: re.Pattern = re.compile(r"\$?\s*([0-9]+(?:\.[0-9]+)?)", re.IGNORECASE)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ----------------------------- helpers -----------------------------------
|
| 33 |
+
def _safe_float(x) -> Optional[float]:
|
| 34 |
+
try:
|
| 35 |
+
if x is None:
|
| 36 |
+
return None
|
| 37 |
+
s = str(x).strip()
|
| 38 |
+
# Strip $ and commas if present (common in meta)
|
| 39 |
+
s = s.replace(",", "")
|
| 40 |
+
if s.startswith("$"):
|
| 41 |
+
s = s[1:]
|
| 42 |
+
v = float(s)
|
| 43 |
+
if not math.isfinite(v):
|
| 44 |
+
return None
|
| 45 |
+
return v
|
| 46 |
+
except Exception:
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _fmt_price(v: float) -> str:
|
| 51 |
+
try:
|
| 52 |
+
return f"${float(v):.2f}"
|
| 53 |
+
except Exception:
|
| 54 |
+
return f"${v}"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _normalize_categories(val) -> List[str]:
|
| 58 |
+
"""
|
| 59 |
+
Normalize 'categories' to list[str], handling:
|
| 60 |
+
- None
|
| 61 |
+
- list/tuple/set of str
|
| 62 |
+
- stringified lists like "['A','B']" OR ["['A','B']"]
|
| 63 |
+
- delimited strings "A > B, C; D"
|
| 64 |
+
"""
|
| 65 |
+
def _from_string(s: str):
|
| 66 |
+
s = s.strip()
|
| 67 |
+
# Try literal list/tuple: "['A','B']" / '["A","B"]' / "(A,B)"
|
| 68 |
+
if (s.startswith("[") and s.endswith("]")) or (s.startswith("(") and s.endswith(")")):
|
| 69 |
+
try:
|
| 70 |
+
parsed = ast.literal_eval(s)
|
| 71 |
+
if isinstance(parsed, (list, tuple, set)):
|
| 72 |
+
return [str(x).strip() for x in parsed if x is not None and str(x).strip()]
|
| 73 |
+
except Exception:
|
| 74 |
+
pass
|
| 75 |
+
# Delimited fallback
|
| 76 |
+
if re.search(r"[>|,/;]+", s):
|
| 77 |
+
return [p.strip() for p in re.split(r"[>|,/;]+", s) if p.strip()]
|
| 78 |
+
return [s] if s else []
|
| 79 |
+
|
| 80 |
+
if val is None:
|
| 81 |
+
return []
|
| 82 |
+
|
| 83 |
+
# Already a container?
|
| 84 |
+
if isinstance(val, (list, tuple, set)):
|
| 85 |
+
out = []
|
| 86 |
+
for x in val:
|
| 87 |
+
if x is None:
|
| 88 |
+
continue
|
| 89 |
+
if isinstance(x, (list, tuple, set)):
|
| 90 |
+
# flatten nested containers
|
| 91 |
+
for y in x:
|
| 92 |
+
if y is None:
|
| 93 |
+
continue
|
| 94 |
+
if isinstance(y, (list, tuple, set)):
|
| 95 |
+
out.extend([str(z).strip() for z in y if z is not None and str(z).strip()])
|
| 96 |
+
elif isinstance(y, str):
|
| 97 |
+
out.extend(_from_string(y))
|
| 98 |
+
else:
|
| 99 |
+
out.append(str(y).strip())
|
| 100 |
+
elif isinstance(x, str):
|
| 101 |
+
out.extend(_from_string(x))
|
| 102 |
+
else:
|
| 103 |
+
out.append(str(x).strip())
|
| 104 |
+
# dedupe + keep order
|
| 105 |
+
seen, dedup = set(), []
|
| 106 |
+
for c in out:
|
| 107 |
+
if c and c not in seen:
|
| 108 |
+
seen.add(c)
|
| 109 |
+
dedup.append(c)
|
| 110 |
+
return dedup
|
| 111 |
+
|
| 112 |
+
# Scalar string
|
| 113 |
+
return _from_string(str(val))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ----------------------------- agent --------------------------------------
|
| 117 |
+
class ChatAgent:
|
| 118 |
+
def __init__(self, config: Optional[ChatAgentConfig] = None) -> None:
|
| 119 |
+
self.config = config or ChatAgentConfig()
|
| 120 |
+
|
| 121 |
+
# ---- parse last user text ----
|
| 122 |
+
def _parse_price_cap(self, text: str) -> Optional[float]:
|
| 123 |
+
m = self.config.price_re.search(text or "")
|
| 124 |
+
if not m:
|
| 125 |
+
return None
|
| 126 |
+
return _safe_float(m.group(1))
|
| 127 |
+
|
| 128 |
+
def _parse_keyword(self, text: str) -> Optional[str]:
|
| 129 |
+
t = (text or "").lower()
|
| 130 |
+
# remove price fragments
|
| 131 |
+
t = self.config.price_re.sub(" ", t)
|
| 132 |
+
# pick first token that isn't a stopword and has letters
|
| 133 |
+
for w in re.findall(r"[a-z][a-z0-9\-]+", t):
|
| 134 |
+
if w in self.config.stopwords:
|
| 135 |
+
continue
|
| 136 |
+
return w
|
| 137 |
+
return None
|
| 138 |
+
|
| 139 |
+
# ---- load catalog ----
|
| 140 |
+
def _items_df(self, dataset: str) -> pd.DataFrame:
|
| 141 |
+
"""
|
| 142 |
+
Load the product catalog from processed data.
|
| 143 |
+
Prefers items_with_meta.parquet (your structure), falls back to joined.parquet.
|
| 144 |
+
Returns a DataFrame; missing columns are filled with sensible defaults.
|
| 145 |
+
"""
|
| 146 |
+
proc = get_processed_path(dataset)
|
| 147 |
+
for fname in ["items_with_meta.parquet", "joined.parquet", "items_meta.parquet", "items.parquet"]:
|
| 148 |
+
fp = proc / fname
|
| 149 |
+
if fp.exists():
|
| 150 |
+
try:
|
| 151 |
+
df = pd.read_parquet(fp)
|
| 152 |
+
break
|
| 153 |
+
except Exception:
|
| 154 |
+
continue
|
| 155 |
+
else:
|
| 156 |
+
# nothing found
|
| 157 |
+
return pd.DataFrame(columns=["item_id", "title", "brand", "price", "categories", "image_url"])
|
| 158 |
+
|
| 159 |
+
# Make sure expected columns exist
|
| 160 |
+
for col in ["item_id", "title", "brand", "price", "categories", "image_url"]:
|
| 161 |
+
if col not in df.columns:
|
| 162 |
+
df[col] = None
|
| 163 |
+
|
| 164 |
+
# Some pipelines store images under imageURL/imageURLHighRes
|
| 165 |
+
if ("image_url" not in df.columns or df["image_url"].isna().all()):
|
| 166 |
+
for alt in ("imageURLHighRes", "imageURL"):
|
| 167 |
+
if alt in df.columns:
|
| 168 |
+
# pick first image if it's a list-like
|
| 169 |
+
def _first_img(v):
|
| 170 |
+
if isinstance(v, (list, tuple)) and v:
|
| 171 |
+
return v[0]
|
| 172 |
+
return v
|
| 173 |
+
df["image_url"] = df[alt].apply(_first_img)
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
return df
|
| 177 |
+
|
| 178 |
+
# --------- main entrypoint expected by API ---------
|
| 179 |
+
def reply(
|
| 180 |
+
self,
|
| 181 |
+
messages: List[Dict[str, str]],
|
| 182 |
+
dataset: Optional[str] = None,
|
| 183 |
+
user_id: Optional[str] = None, # unused in this simple baseline
|
| 184 |
+
k: int = 5,
|
| 185 |
+
) -> Dict[str, Any]:
|
| 186 |
+
"""
|
| 187 |
+
Baseline behavior:
|
| 188 |
+
- Parse last user message → (keyword, price cap)
|
| 189 |
+
- Filter catalog by price<=cap and keyword match in title/brand/categories
|
| 190 |
+
- Rank by lowest price (as a proxy score)
|
| 191 |
+
- Return top-k with normalized fields
|
| 192 |
+
"""
|
| 193 |
+
if not dataset:
|
| 194 |
+
dataset = "beauty"
|
| 195 |
+
|
| 196 |
+
# last user utterance
|
| 197 |
+
last_user = ""
|
| 198 |
+
for m in reversed(messages or []):
|
| 199 |
+
if (m.get("role") or "").lower() == "user":
|
| 200 |
+
last_user = m.get("content") or ""
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
cap = self._parse_price_cap(last_user)
|
| 204 |
+
kw = self._parse_keyword(last_user)
|
| 205 |
+
|
| 206 |
+
df = self._items_df(dataset)
|
| 207 |
+
|
| 208 |
+
# Column presence map for debugging
|
| 209 |
+
colmap = {
|
| 210 |
+
"item_id": "item_id" if "item_id" in df.columns else None,
|
| 211 |
+
"title": "title" if "title" in df.columns else None,
|
| 212 |
+
"brand": "brand" if "brand" in df.columns else None,
|
| 213 |
+
"price": "price" if "price" in df.columns else None,
|
| 214 |
+
"categories": "categories" if "categories" in df.columns else None,
|
| 215 |
+
"image_url": "image_url" if "image_url" in df.columns else None,
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
# ------- filtering -------
|
| 219 |
+
if len(df) == 0:
|
| 220 |
+
sub = df
|
| 221 |
+
else:
|
| 222 |
+
mask = pd.Series(True, index=df.index)
|
| 223 |
+
|
| 224 |
+
# price filter
|
| 225 |
+
if cap is not None and colmap["price"]:
|
| 226 |
+
price_num = df[colmap["price"]].apply(_safe_float)
|
| 227 |
+
mask &= pd.to_numeric(price_num, errors="coerce").le(cap)
|
| 228 |
+
|
| 229 |
+
# keyword filter (title OR brand OR categories)
|
| 230 |
+
if kw:
|
| 231 |
+
kw_l = kw.lower()
|
| 232 |
+
parts = []
|
| 233 |
+
if colmap["title"]:
|
| 234 |
+
parts.append(df[colmap["title"]].astype(str).str.lower().str.contains(kw_l, na=False))
|
| 235 |
+
if colmap["brand"]:
|
| 236 |
+
parts.append(df[colmap["brand"]].astype(str).str.lower().str.contains(kw_l, na=False))
|
| 237 |
+
if colmap["categories"]:
|
| 238 |
+
parts.append(df[colmap["categories"]].astype(str).str.lower().str.contains(kw_l, na=False))
|
| 239 |
+
if parts:
|
| 240 |
+
m_any = parts[0]
|
| 241 |
+
for p in parts[1:]:
|
| 242 |
+
m_any = m_any | p
|
| 243 |
+
mask &= m_any
|
| 244 |
+
|
| 245 |
+
sub = df[mask].copy()
|
| 246 |
+
|
| 247 |
+
# ------- scoring & sorting (cheaper → higher score) -------
|
| 248 |
+
if len(sub) > 0:
|
| 249 |
+
price_num = sub[colmap["price"]].apply(_safe_float) if colmap["price"] else 0.0
|
| 250 |
+
sub["score"] = pd.to_numeric(price_num, errors="coerce").apply(
|
| 251 |
+
lambda p: 1.0 / (p + 1e-6) if pd.notnull(p) and p > 0 else 0.0
|
| 252 |
+
)
|
| 253 |
+
sort_cols = ["score"]
|
| 254 |
+
ascending = [False]
|
| 255 |
+
if colmap["brand"]:
|
| 256 |
+
sort_cols.append(colmap["brand"])
|
| 257 |
+
ascending.append(True)
|
| 258 |
+
if colmap["title"]:
|
| 259 |
+
sort_cols.append(colmap["title"])
|
| 260 |
+
ascending.append(True)
|
| 261 |
+
sub = sub.sort_values(by=sort_cols, ascending=ascending).head(max(1, int(k)))
|
| 262 |
+
|
| 263 |
+
# ------- build recs -------
|
| 264 |
+
recs: List[Dict[str, Any]] = []
|
| 265 |
+
for _, r in sub.iterrows():
|
| 266 |
+
recs.append(
|
| 267 |
+
{
|
| 268 |
+
"item_id": r.get(colmap["item_id"]) if colmap["item_id"] else None,
|
| 269 |
+
"score": float(r.get("score") or 0.0),
|
| 270 |
+
"brand": (r.get(colmap["brand"]) if colmap["brand"] else None) or None,
|
| 271 |
+
"price": _safe_float(r.get(colmap["price"]) if colmap["price"] else None),
|
| 272 |
+
"categories": _normalize_categories(r.get(colmap["categories"]) if colmap["categories"] else None),
|
| 273 |
+
"image_url": (r.get(colmap["image_url"]) if colmap["image_url"] else None) or None,
|
| 274 |
+
}
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# Fallback: if filter empty, return cheapest k overall
|
| 278 |
+
if not recs and len(df) > 0:
|
| 279 |
+
df2 = df.copy()
|
| 280 |
+
pnum = df2[colmap["price"]].apply(_safe_float) if colmap["price"] else None
|
| 281 |
+
df2["pnum"] = pd.to_numeric(pnum, errors="coerce")
|
| 282 |
+
df2 = df2.sort_values(by=["pnum"]).head(max(1, int(k)))
|
| 283 |
+
for _, r in df2.iterrows():
|
| 284 |
+
recs.append(
|
| 285 |
+
{
|
| 286 |
+
"item_id": r.get(colmap["item_id"]) if colmap["item_id"] else None,
|
| 287 |
+
"score": 0.0,
|
| 288 |
+
"brand": (r.get(colmap["brand"]) if colmap["brand"] else None) or None,
|
| 289 |
+
"price": _safe_float(r.get(colmap["price"]) if colmap["price"] else None),
|
| 290 |
+
"categories": _normalize_categories(r.get(colmap["categories"]) if colmap["categories"] else None),
|
| 291 |
+
"image_url": (r.get(colmap["image_url"]) if colmap["image_url"] else None) or None,
|
| 292 |
+
}
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# reply sentence
|
| 296 |
+
reply_bits = []
|
| 297 |
+
if kw:
|
| 298 |
+
reply_bits.append(f"**{kw}**")
|
| 299 |
+
if cap is not None:
|
| 300 |
+
reply_bits.append(f"≤ {_fmt_price(cap)}")
|
| 301 |
+
reply_str = "I found items " + (" ".join(reply_bits) if reply_bits else "you might like") + f" on **{dataset}**."
|
| 302 |
+
|
| 303 |
+
# Helpful debug
|
| 304 |
+
debug = {
|
| 305 |
+
"parsed_keyword": kw,
|
| 306 |
+
"price_cap": cap,
|
| 307 |
+
"matched": len(recs),
|
| 308 |
+
"colmap": colmap,
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
return {"reply": reply_str, "recommendations": recs, "debug": debug}
|
src/agents/data_agent.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/agents/data_agent.py
|
| 2 |
+
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Literal
|
| 8 |
+
import urllib.request
|
| 9 |
+
|
| 10 |
+
class DataAgent:
|
| 11 |
+
"""
|
| 12 |
+
Runs data prep scripts for a dataset:
|
| 13 |
+
- Downloads raw files if not present
|
| 14 |
+
- join_meta.py
|
| 15 |
+
- build_text_emb.py
|
| 16 |
+
- build_image_emb.py
|
| 17 |
+
- build_meta_emb.py
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def _run(self, argv):
|
| 21 |
+
print("→", " ".join(argv))
|
| 22 |
+
subprocess.check_call(argv)
|
| 23 |
+
|
| 24 |
+
def _download_raw_data(self, dataset: str):
|
| 25 |
+
if dataset != "beauty":
|
| 26 |
+
raise ValueError(f"Auto-download is only supported for 'beauty' dataset")
|
| 27 |
+
|
| 28 |
+
base_dir = Path("data/raw/beauty")
|
| 29 |
+
base_dir.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
files = {
|
| 32 |
+
"reviews.json": "https://huggingface.co/datasets/mickey1976/mayankc-amazon_beauty_subset/resolve/main/reviews.json",
|
| 33 |
+
"meta.json": "https://huggingface.co/datasets/mickey1976/mayankc-amazon_beauty_subset/resolve/main/meta.json",
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
for fname, url in files.items():
|
| 37 |
+
out_path = base_dir / fname
|
| 38 |
+
if not out_path.exists():
|
| 39 |
+
print(f"⬇️ Downloading {fname}...")
|
| 40 |
+
urllib.request.urlretrieve(url, out_path)
|
| 41 |
+
print(f"✅ Saved to {out_path}")
|
| 42 |
+
else:
|
| 43 |
+
print(f"✔️ Already exists: {out_path}")
|
| 44 |
+
|
| 45 |
+
def prepare(self, dataset: Literal["beauty"] = "beauty"):
|
| 46 |
+
print(f"
|
src/agents/index_agent.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/agents/index_agent.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import subprocess
|
| 4 |
+
import sys
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
@dataclass
|
| 8 |
+
class IndexConfig:
|
| 9 |
+
dataset: str
|
| 10 |
+
fusion: str = "concat" # "concat" | "weighted"
|
| 11 |
+
w_text: float = 1.0
|
| 12 |
+
w_image: float = 1.0
|
| 13 |
+
w_meta: float = 0.0
|
| 14 |
+
out_name: str = "" # e.g. "beauty_concat_best"
|
| 15 |
+
|
| 16 |
+
class IndexAgent:
|
| 17 |
+
def _run(self, argv: list[str]) -> None:
|
| 18 |
+
# Run the CLI step in the same interpreter/venv
|
| 19 |
+
subprocess.check_call(argv)
|
| 20 |
+
|
| 21 |
+
def build(self, cfg: IndexConfig) -> None:
|
| 22 |
+
args = [
|
| 23 |
+
sys.executable, "scripts/build_faiss.py",
|
| 24 |
+
"--dataset", cfg.dataset,
|
| 25 |
+
"--fusion", cfg.fusion,
|
| 26 |
+
"--w_text", str(cfg.w_text),
|
| 27 |
+
"--w_image", str(cfg.w_image),
|
| 28 |
+
"--w_meta", str(cfg.w_meta),
|
| 29 |
+
]
|
| 30 |
+
if cfg.out_name:
|
| 31 |
+
args += ["--out_name", cfg.out_name]
|
| 32 |
+
print("→", " ".join(args))
|
| 33 |
+
self._run(args)
|
| 34 |
+
print("✓ Index build complete.")
|
src/agents/model_agent.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess, sys
|
| 2 |
+
|
| 3 |
+
class ModelAgent:
|
| 4 |
+
"""Runs evaluation / sweeps for fusion strategies."""
|
| 5 |
+
def eval(self, dataset: str="beauty"):
|
| 6 |
+
print("→ eval fusion on", dataset)
|
| 7 |
+
subprocess.check_call([sys.executable, "scripts/eval_fusion.py", "--dataset", dataset])
|
| 8 |
+
print("✓ Evaluation complete.")
|
src/agents/orchestrator.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# src/agents/orchestrator.py
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
def run_eval(dataset: str):
|
| 8 |
+
runs = [
|
| 9 |
+
# 1. No FAISS - Weighted Fusion
|
| 10 |
+
["scripts/eval_fusion.py", "--dataset", dataset, "--fusion", "weighted", "--use_defaults", "--k", "10", "--run_name", "weighted"],
|
| 11 |
+
|
| 12 |
+
# 2. No FAISS - Concat Fusion
|
| 13 |
+
["scripts/eval_fusion.py", "--dataset", dataset, "--fusion", "concat", "--use_defaults", "--k", "10", "--run_name", "concat"],
|
| 14 |
+
|
| 15 |
+
# 3. FAISS - Weighted Fusion
|
| 16 |
+
["scripts/eval_fusion.py", "--dataset", dataset, "--fusion", "weighted", "--use_defaults", "--use_faiss", "--k", "10", "--run_name", "cove_faiss_weighted"],
|
| 17 |
+
|
| 18 |
+
# 4. FAISS - Concat Fusion
|
| 19 |
+
["scripts/eval_fusion.py", "--dataset", dataset, "--fusion", "concat", "--use_defaults", "--use_faiss", "--k", "10", "--run_name", "cove_faiss_concat"],
|
| 20 |
+
|
| 21 |
+
# 5. CoVE FAISS Only + Logits
|
| 22 |
+
["scripts/eval_cove.py", "--dataset", dataset, "--mode", "cove_faiss_only", "--save_candidates"],
|
| 23 |
+
["scripts/eval_logits_cove.py", dataset],
|
| 24 |
+
|
| 25 |
+
# 6. CoVE FAISS Concat + Logits
|
| 26 |
+
["scripts/eval_cove.py", "--dataset", dataset, "--mode", "cove_faiss_concat", "--save_candidates"],
|
| 27 |
+
["scripts/eval_logits_cove.py", dataset],
|
| 28 |
+
|
| 29 |
+
# 7. Full CoVE Logits (pure model)
|
| 30 |
+
["scripts/eval_cove.py", "--dataset", dataset, "--mode", "cove_logits", "--full"],
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
for i, cmd in enumerate(runs, 1):
|
| 34 |
+
print(f"\n[🚀] Running {i}/{len(runs)}: {' '.join(cmd)}")
|
| 35 |
+
subprocess.run(["PYTHONPATH=./src"] + cmd, check=True, shell=False)
|
| 36 |
+
|
| 37 |
+
def main():
|
| 38 |
+
parser = argparse.ArgumentParser()
|
| 39 |
+
parser.add_argument("--dataset", required=True)
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
run_eval(args.dataset)
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
main()
|
src/agents/recommend_agent.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import urllib.request
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
class RecommendAgent:
|
| 6 |
+
"""
|
| 7 |
+
Hits your local FastAPI recommender (port 8000).
|
| 8 |
+
"""
|
| 9 |
+
def __init__(self, api_base: str="http://127.0.0.1:8000"):
|
| 10 |
+
self.api_base = api_base.rstrip("/")
|
| 11 |
+
|
| 12 |
+
def recommend(self,
|
| 13 |
+
dataset: str,
|
| 14 |
+
user: str,
|
| 15 |
+
k: int = 10,
|
| 16 |
+
fusion: str = "concat",
|
| 17 |
+
w_text: float = 1.0,
|
| 18 |
+
w_image: float = 1.0,
|
| 19 |
+
w_meta: float = 0.0,
|
| 20 |
+
use_faiss: bool = True,
|
| 21 |
+
faiss_name: Optional[str] = None,
|
| 22 |
+
exclude_seen: bool = True):
|
| 23 |
+
payload = {
|
| 24 |
+
"dataset": dataset, "user_id": user, "k": k,
|
| 25 |
+
"fusion": fusion, "w_text": w_text, "w_image": w_image, "w_meta": w_meta,
|
| 26 |
+
"use_faiss": use_faiss, "exclude_seen": exclude_seen
|
| 27 |
+
}
|
| 28 |
+
if use_faiss and faiss_name:
|
| 29 |
+
payload["faiss_name"] = faiss_name
|
| 30 |
+
|
| 31 |
+
url = f"{self.api_base}/recommend"
|
| 32 |
+
req = urllib.request.Request(url, data=json.dumps(payload).encode("utf-8"),
|
| 33 |
+
headers={"Content-Type": "application/json"})
|
| 34 |
+
with urllib.request.urlopen(req) as resp:
|
| 35 |
+
body = resp.read()
|
| 36 |
+
data = json.loads(body)
|
| 37 |
+
return data
|
src/agents/report_agent.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# src/agents/report_agent.py
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import List, Dict, Any, Optional
|
| 10 |
+
|
| 11 |
+
import pandas as pd
|
| 12 |
+
|
| 13 |
+
HERE = Path(__file__).resolve().parent
|
| 14 |
+
ROOT = HERE.parents[2] # repo root (/notebooks/MMR-Agentic-CoVE)
|
| 15 |
+
LOGS = ROOT / "logs"
|
| 16 |
+
PLOTS = LOGS / "plots"
|
| 17 |
+
REPORTS_ROOT = ROOT / "reports"
|
| 18 |
+
|
| 19 |
+
def _ensure_dir(p: Path):
|
| 20 |
+
p.mkdir(parents=True, exist_ok=True)
|
| 21 |
+
|
| 22 |
+
def _load_metrics(csv_fp: Path) -> pd.DataFrame:
|
| 23 |
+
if not csv_fp.exists():
|
| 24 |
+
raise FileNotFoundError(f"Missing metrics CSV: {csv_fp}")
|
| 25 |
+
df = pd.read_csv(csv_fp, engine="python", on_bad_lines="skip")
|
| 26 |
+
# normalize
|
| 27 |
+
for col in ["run_name", "dataset", "fusion"]:
|
| 28 |
+
if col not in df.columns:
|
| 29 |
+
df[col] = ""
|
| 30 |
+
df[col] = df[col].fillna("").astype(str)
|
| 31 |
+
for wcol in ["w_text", "w_image", "w_meta"]:
|
| 32 |
+
if wcol not in df.columns:
|
| 33 |
+
df[wcol] = float("nan")
|
| 34 |
+
# convenience flags
|
| 35 |
+
if "faiss" not in df.columns:
|
| 36 |
+
df["faiss"] = df["run_name"].str.contains("faiss", case=False, na=False).astype(bool)
|
| 37 |
+
return df
|
| 38 |
+
|
| 39 |
+
def _metric_cols(df: pd.DataFrame, k: int) -> Dict[str, Optional[str]]:
|
| 40 |
+
# prefer explicit @k
|
| 41 |
+
hit_col = f"hit@{k}" if f"hit@{k}" in df.columns else ("hit" if "hit" in df.columns else None)
|
| 42 |
+
ndcg_col = f"ndcg@{k}" if f"ndcg@{k}" in df.columns else ("ndcg" if "ndcg" in df.columns else None)
|
| 43 |
+
return {"hit": hit_col, "ndcg": ndcg_col}
|
| 44 |
+
|
| 45 |
+
def _top_n_table(
    df: pd.DataFrame, dataset: str, k: int, top_n: int = 5,
    prefer_faiss: bool = True
) -> tuple[pd.DataFrame, Dict[str, Any]]:
    """Rank runs for *dataset* at cutoff *k* and pick a recommendation.

    Sorts by ndcg (desc), then hit (desc); when *prefer_faiss* is set, FAISS
    runs win ties. Returns ``(top_n table, recommendation dict for row #1)``.

    Raises:
        ValueError: if no hit/ndcg columns exist for *k*, or if no rows match
            *dataset* (previously crashed with an opaque IndexError on
            ``top.iloc[0]``).
    """
    df = df.copy()
    if "dataset" in df.columns:
        df = df[df["dataset"] == dataset]
    if df.empty:
        # Fail early with a clear message instead of IndexError below.
        raise ValueError(f"No metric rows found for dataset={dataset!r}")

    cols = _metric_cols(df, k)
    hitc, ndcgc = cols["hit"], cols["ndcg"]
    if not ndcgc and not hitc:
        raise ValueError(f"No hit/ndcg columns found for k={k}. Available: {list(df.columns)}")

    # Sort keys: ndcg desc, then hit desc; optional FAISS preference on ties.
    sort_cols = [c for c in (ndcgc, hitc) if c]
    df["_faiss"] = df.get("faiss", df["run_name"].str.contains("faiss", case=False, na=False)).astype(int)
    by = sort_cols + (["_faiss"] if prefer_faiss else [])
    df_sorted = df.sort_values(by=by, ascending=[False] * len(by))

    # Compact table for the report body.
    keep_cols = ["run_name", "dataset", "fusion", "w_text", "w_image", "w_meta"]
    keep_cols += [c for c in (hitc, ndcgc) if c]
    top = df_sorted[keep_cols].head(top_n).reset_index(drop=True)

    def _as_float(v) -> Optional[float]:
        # NaN-safe conversion so the JSON recommendation never contains NaN
        # (json.dumps would otherwise emit invalid "NaN" literals).
        return float(v) if v is not None and pd.notna(v) else None

    # Recommendation = the first (best-ranked) row.
    rec_row = top.iloc[0].to_dict()
    rec = {
        "dataset": dataset,
        "k": k,
        "recommended_run": rec_row["run_name"],
        "fusion": rec_row.get("fusion"),
        "weights": {
            "w_text": _as_float(rec_row.get("w_text")),
            "w_image": _as_float(rec_row.get("w_image")),
            "w_meta": _as_float(rec_row.get("w_meta")),
        },
        "metrics": {
            (hitc or "hit"): _as_float(rec_row.get(hitc)) if hitc else None,
            (ndcgc or "ndcg"): _as_float(rec_row.get(ndcgc)) if ndcgc else None,
        },
    }
    return top, rec
|
| 92 |
+
|
| 93 |
+
def _md_table(df: pd.DataFrame) -> str:
|
| 94 |
+
"""
|
| 95 |
+
Return a markdown-ish table. Falls back to a preformatted text block if
|
| 96 |
+
pandas' to_markdown requires 'tabulate' and it's not installed.
|
| 97 |
+
"""
|
| 98 |
+
try:
|
| 99 |
+
return df.to_markdown(index=False)
|
| 100 |
+
except Exception:
|
| 101 |
+
# Fallback: plain text inside code fences so the report still renders.
|
| 102 |
+
return "```\n" + df.to_string(index=False) + "\n```"
|
| 103 |
+
def _copy_plots_into(out_dir: Path, dataset: str) -> list[str]:
    """Copy the dataset's known plot files from logs/plots into *out_dir*.

    Returns the filenames that were actually copied. Files missing from
    logs/plots are skipped, and individual copy failures are ignored so the
    report stays best-effort.
    """
    import shutil  # hoisted: previously re-imported on every loop iteration

    wanted = [
        f"{dataset}_k10_quality.png",
        f"{dataset}_k10_quality_trend.png",
        f"{dataset}_k10_latency.png",
        f"{dataset}_w_meta_ndcg@10.png",
        f"{dataset}_w_meta_hit@10.png",
        f"{dataset}_k_ndcg@10.png",
    ]
    copied: list[str] = []
    for name in wanted:
        src = PLOTS / name
        if not src.exists():
            continue
        try:
            shutil.copy2(src, out_dir / name)
            copied.append(name)
        except Exception:
            # Best-effort: one unreadable plot must not kill the whole report.
            pass
    return copied
|
| 128 |
+
|
| 129 |
+
def _baseline_quadrant(df: pd.DataFrame, dataset: str, k: int) -> Optional[pd.DataFrame]:
    """Build a compact 2x2 best-run comparison: (No-FAISS/FAISS) x (concat/weighted).

    For each quadrant the best row is chosen by ndcg desc, then hit desc.
    Returns None when the required columns are missing or no quadrant has any
    metric value at all.
    """
    cols = _metric_cols(df, k)
    hitc, ndcgc = cols["hit"], cols["ndcg"]
    if not ndcgc and not hitc:
        return None

    d = df.copy()
    if "dataset" in d.columns:
        d = d[d["dataset"] == dataset]
    if "fusion" not in d.columns:
        return None
    if "faiss" not in d.columns:
        d["faiss"] = d["run_name"].str.contains("faiss", case=False, na=False).astype(bool)

    # Single combined sort per quadrant. The original sorted twice when both
    # metrics existed (the first sort was dead work); the final ordering here
    # is identical: ndcg desc, then hit desc, falling back to whichever exists.
    sort_keys = [c for c in (ndcgc, hitc) if c]

    rows = []
    for fa in (False, True):
        for fu in ("concat", "weighted"):
            sub = d[(d["fusion"].str.lower() == fu) & (d["faiss"] == fa)]
            sub = sub.sort_values(sort_keys, ascending=[False] * len(sort_keys))
            if sub.empty:
                # Placeholder row so the quadrant stays visible in the table.
                rows.append({"faiss": "Yes" if fa else "No", "fusion": fu,
                             "run_name": "—", "hit@k": None, "ndcg@k": None})
            else:
                r = sub.iloc[0]
                rows.append({
                    "faiss": "Yes" if fa else "No",
                    "fusion": fu,
                    "run_name": r.get("run_name", ""),
                    "hit@k": (float(r[hitc]) if hitc else None),
                    "ndcg@k": (float(r[ndcgc]) if ndcgc else None),
                })
    out = pd.DataFrame(rows, columns=["faiss", "fusion", "run_name", "hit@k", "ndcg@k"])
    # Return None if literally no metrics were found in any quadrant.
    if out[["hit@k", "ndcg@k"]].isna().all().all():
        return None
    return out
|
| 172 |
+
|
| 173 |
+
def _write_report(
    out_dir: Path,
    tag: str,
    dataset: str,
    k: Optional[int],
    include_compare: bool,
    top_n: int,
    prefer_faiss: bool,
    metrics_csv: Path,
) -> None:
    """Assemble a self-contained report under *out_dir*.

    Produces: copied plot PNGs, index.md, index.html, a filtered metrics.csv
    snapshot, and summary.json. When *include_compare* is set (and *k* given),
    a Top-runs table, an auto recommendation, and a 4-way FAISS×fusion
    comparison are embedded. All failures in optional sections degrade to an
    inline note rather than aborting the report.
    """
    _ensure_dir(out_dir)

    # Self-contained: copy plots into the report directory
    copied_plots = _copy_plots_into(out_dir, dataset)

    # optional compare section + recommendation
    compare_md = ""
    summary_json: Dict[str, Any] = {}
    if include_compare and k is not None:
        df_all = _load_metrics(metrics_csv)
        try:
            top, rec = _top_n_table(df_all, dataset=dataset, k=k, top_n=top_n, prefer_faiss=prefer_faiss)
            # Rename the @k columns to generic names so the table header is stable.
            compare_md = (
                "## Top runs (auto)\n\n"
                + _md_table(top.rename(columns={
                    f"hit@{k}": "hit@k", f"ndcg@{k}": "ndcg@k"
                })) + "\n\n"
                "### Recommendation (auto)\n\n"
                "```json\n" + json.dumps(rec, indent=2) + "\n```\n"
            )
            summary_json["recommendation"] = rec
            # Round-trip through to_json to get plain JSON-serializable records.
            summary_json["top_runs"] = json.loads(top.to_json(orient="records"))

            # Add a 4-way baseline quadrant if possible
            quad = _baseline_quadrant(df_all, dataset=dataset, k=k)
            if quad is not None:
                compare_md += "\n### Baseline 4-way comparison (FAISS × Fusion)\n\n"
                compare_md += _md_table(quad) + "\n"
                summary_json["baseline_quadrant"] = json.loads(quad.to_json(orient="records"))
        except Exception as e:
            # Comparison is optional: surface the failure inside the report.
            compare_md = f"> Could not compute comparison for k={k}: {e}\n"

    # build markdown
    md_parts = [f"# {dataset} — {tag}\n"]
    if include_compare and k is not None:
        md_parts.append(compare_md)

    if copied_plots:
        md_parts.append("## Plots\n")
        for name in copied_plots:
            # NOTE(review): this f-string appears truncated by extraction — it
            # likely embedded a markdown image link per plot (e.g.
            # f"![{name}]({name})\n"); confirm against the repository source.
            md_parts.append(f"\n")

    # metrics snapshot (also save a filtered CSV into the report for grading)
    try:
        dfm = _load_metrics(metrics_csv)
        snap = dfm[dfm["dataset"] == dataset] if "dataset" in dfm.columns else dfm
        md_parts.append("## Metrics snapshot\n\n")
        # Show only well-known columns when present; otherwise the first 10.
        show_cols = [c for c in ["run_name","dataset","fusion","w_text","w_image","w_meta",
                                 "k","hit","ndcg","hit@5","ndcg@5","hit@10","ndcg@10","hit@20","ndcg@20","p50_ms","p95_ms"]
                     if c in snap.columns]
        if not show_cols:
            show_cols = list(snap.columns)[:10]
        md_parts.append(_md_table(snap[show_cols].tail(20)) + "\n")
        # Save a compact CSV snapshot into the report folder
        snap.to_csv(out_dir / "metrics.csv", index=False)
    except Exception as e:
        md_parts.append(f"> Could not include metrics snapshot: {e}\n")

    # write index.md
    md_path = out_dir / "index.md"
    md_path.write_text("\n".join(md_parts), encoding="utf-8")

    # render HTML (pretty if markdown package available; otherwise fallback)
    html_path = out_dir / "index.html"
    try:
        import markdown  # type: ignore
        html = markdown.markdown(md_path.read_text(encoding="utf-8"), extensions=["tables"])
        html_full = [
            "<html><head><meta charset='utf-8'><title>Report</title>",
            "<style>body{font-family:ui-sans-serif,system-ui,-apple-system,Segoe UI,Roboto;max-width:900px;margin:40px auto;padding:0 16px} table{border-collapse:collapse} th,td{border:1px solid #ddd;padding:6px 8px}</style>",
            "</head><body>",
            html,
            "</body></html>",
        ]
        html_path.write_text("\n".join(html_full), encoding="utf-8")
    except Exception:
        # simple fallback: serve the raw markdown inside <pre> so something renders
        html = [
            "<html><head><meta charset='utf-8'><title>Report</title></head><body>",
            f"<pre style='font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, monospace'>{md_path.read_text(encoding='utf-8')}</pre>",
            "</body></html>",
        ]
        html_path.write_text("\n".join(html), encoding="utf-8")

    # write summary.json (report metadata + any comparison results gathered above)
    (out_dir / "summary.json").write_text(json.dumps({
        "dataset": dataset,
        "tag": tag,
        "k": k,
        "include_compare": include_compare,
        **summary_json
    }, indent=2), encoding="utf-8")
|
| 275 |
+
|
| 276 |
+
def main():
    """CLI entry point: parse flags, assemble the report, optionally zip it."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--tag", default="report")
    ap.add_argument("--k", type=int, default=10, help="k to use for comparison tables")
    ap.add_argument("--include-compare", action="store_true", help="Include Top runs + Recommendation section")
    ap.add_argument("--top-n", type=int, default=3, help="How many runs to show in the Top table")
    ap.add_argument("--prefer-faiss", action="store_true", help="Prefer FAISS runs when metrics tie")
    ap.add_argument("--metrics_csv", default=str(LOGS / "metrics.csv"))
    ap.add_argument("--plots_dir", default=str(PLOTS))
    ap.add_argument("--out", default="", help="Optional explicit out path (file or directory)")
    ap.add_argument("--no-plots", action="store_true", help="(kept for back-compat; plots are referenced if present)")
    ap.add_argument("--zip", action="store_true", help="Zip the report folder")
    args = ap.parse_args()

    dataset, tag = args.dataset, args.tag
    # Default output: reports/<dataset>/<timestamp> <tag>/ unless --out is given.
    if args.out:
        out_dir = Path(args.out)
    else:
        out_dir = REPORTS_ROOT / dataset / f"{pd.Timestamp.now():%Y%m%d_%H%M%S} {tag}"
    _ensure_dir(out_dir)

    # Assemble the report; k only matters when the compare section is requested.
    _write_report(
        out_dir=out_dir,
        tag=tag,
        dataset=dataset,
        k=args.k if args.include_compare else None,
        include_compare=args.include_compare,
        top_n=args.top_n,
        prefer_faiss=args.prefer_faiss,
        metrics_csv=Path(args.metrics_csv),
    )

    print(f"→ Assembling report at {out_dir}")
    print(f"✓ Report ready: {out_dir}")

    if args.zip:
        import shutil
        zpath = out_dir.with_suffix(".zip")
        shutil.make_archive(str(zpath.with_suffix("")), "zip", out_dir.parent, out_dir.name)
        print(f"📦 Zipped → {zpath}")

if __name__ == "__main__":
    main()
|
src/agents/run_agent.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# scripts/run_agent.py
import argparse
from agents.data_agent import DataAgent
from agents.index_agent import IndexAgent

if __name__ == "__main__":
    # Tiny dispatcher: a single --intent flag routes to the matching agent.
    parser = argparse.ArgumentParser()
    parser.add_argument("--intent", required=True, choices=["prepare", "index"], help="What the agent should do")
    parser.add_argument("--dataset", default="beauty", help="Dataset name")
    parser.add_argument("--fusion", choices=["concat", "weighted"], help="Fusion mode (for indexing)")
    parser.add_argument("--w_text", type=float, default=1.0, help="Weight for text embeddings")
    parser.add_argument("--w_image", type=float, default=1.0, help="Weight for image embeddings")
    parser.add_argument("--w_meta", type=float, default=1.0, help="Weight for meta embeddings")
    parser.add_argument("--faiss_name", default="default_index", help="Name for the FAISS index output")
    args = parser.parse_args()

    if args.intent == "prepare":
        DataAgent().prepare(args.dataset)
    else:
        # "index" — the only other value argparse's choices= permits.
        IndexAgent().index(
            dataset=args.dataset,
            fusion=args.fusion,
            w_text=args.w_text,
            w_image=args.w_image,
            w_meta=args.w_meta,
            faiss_name=args.faiss_name,
        )
|
src/cove/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
|
File without changes
|
src/cove/.ipynb_checkpoints/fuse_index-checkpoint.py
ADDED
|
File without changes
|
src/cove/.ipynb_checkpoints/io-checkpoint.py
ADDED
|
File without changes
|
src/cove/__init__.py
ADDED
|
File without changes
|
src/cove/fuse_index.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/cove/fuse_index.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Tuple, Optional
|
| 5 |
+
import numpy as np
|
| 6 |
+
import faiss # pip install faiss-cpu
|
| 7 |
+
from .io import read_item_parquet, align_by_ids
|
| 8 |
+
|
| 9 |
+
def l2norm_rows(M: np.ndarray) -> np.ndarray:
    """Row-wise L2 normalization; the epsilon keeps zero rows from dividing by zero."""
    norms = np.linalg.norm(M, axis=1, keepdims=True)
    return M / (norms + 1e-12)
|
| 11 |
+
|
| 12 |
+
def concat_fuse(parts: Tuple[np.ndarray, ...], weights: Tuple[float, ...]) -> np.ndarray:
    """Horizontally concatenate the weighted, non-empty parts as float32.

    Parts that are None, empty, or carry a zero weight are dropped entirely.

    Raises:
        ValueError: when every part is dropped.
    """
    scaled = [w * X for X, w in zip(parts, weights)
              if X is not None and X.size != 0 and w != 0.0]
    if not scaled:
        raise ValueError("Nothing to fuse.")
    return np.concatenate(scaled, axis=1).astype(np.float32)
|
| 21 |
+
|
| 22 |
+
def weighted_sum(Vt: np.ndarray,
                 Vi: Optional[np.ndarray],
                 Vm: Optional[np.ndarray],
                 wt=1.0, wi=0.0, wm=0.0) -> np.ndarray:
    """Element-wise weighted sum of the modality matrices.

    Vi/Vm are skipped when absent or zero-weighted; every contributing matrix
    must share Vt's width.

    Raises:
        ValueError: if a contributing matrix's width differs from Vt's.
    """
    dim = Vt.shape[1]
    acc = wt * Vt
    for V, w in ((Vi, wi), (Vm, wm)):
        if V is None or w == 0.0:
            continue
        if V.shape[1] != dim:
            raise ValueError("Weighted-sum requires equal dims.")
        acc = acc + w * V
    return acc.astype(np.float32)
|
| 37 |
+
|
| 38 |
+
def build_ivfpq(V: np.ndarray, nlist=2048, m=32, nbits=8, use_opq: bool=False):
    """Train an IVF-PQ FAISS index over the L2-normalized rows of V.

    Uses the inner-product metric (cosine similarity after normalization).
    Returns (index, opq) where opq is the trained OPQ rotation, or None when
    use_opq is False.

    NOTE(review): nlist=2048 needs at least that many training vectors for a
    clean k-means — confirm callers pass enough rows.
    """
    dim = V.shape[1]
    # Normalize rows so inner product behaves like cosine similarity.
    Vn = l2norm_rows(V)

    # optional OPQ: orthogonal rotation before PQ
    opq = None
    if use_opq:
        opq = faiss.OPQMatrix(dim, m)
        opq.train(Vn)
        # Rotate BEFORE index.train/add so the index only ever sees rotated space.
        Vn = opq.apply_py(Vn)

    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)
    # Order matters: IVF-PQ must be trained before vectors can be added.
    index.train(Vn)
    index.add(Vn)
    return index, opq
|
| 54 |
+
|
| 55 |
+
def build_pq(V: np.ndarray, m=32, nbits=8, use_opq: bool=False):
    """Train a flat PQ FAISS index (no IVF partitioning) over normalized rows of V.

    Inner-product metric; returns (index, opq) where opq is the trained OPQ
    rotation or None when use_opq is False.
    """
    dim = V.shape[1]
    # Normalize rows so inner product behaves like cosine similarity.
    Vn = l2norm_rows(V)
    opq = None
    if use_opq:
        opq = faiss.OPQMatrix(dim, m)
        opq.train(Vn)
        # Rotate BEFORE index.train/add so the index only ever sees rotated space.
        Vn = opq.apply_py(Vn)
    index = faiss.IndexPQ(dim, m, nbits, faiss.METRIC_INNER_PRODUCT)
    # Order matters: PQ codebooks must be trained before vectors can be added.
    index.train(Vn)
    index.add(Vn)
    return index, opq
|
| 67 |
+
|
| 68 |
+
def save_index(out_dir: Path, base: str, index, item_ids, opq=None):
    """Persist index artifacts under *out_dir*: <base>.faiss (the index),
    <base>.npy (row-aligned item ids), and <base>.opq when an OPQ rotation
    was used at build time."""
    out_dir.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(out_dir / f"{base}.faiss"))
    # Object dtype so string item ids round-trip through np.save unchanged.
    np.save(out_dir / f"{base}.npy", np.array(item_ids, dtype=object))
    if opq is not None:
        faiss.write_VectorTransform(opq, str(out_dir / f"{base}.opq"))
|
| 74 |
+
|
| 75 |
+
def fuse_mm_with_cove(
    proc_dir: Path,
    cove_fp: Path,
    fusion: str = "concat",
    w_text=1.0, w_image=0.0, w_meta=0.0, w_cove=1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """Fuse the multimodal item embeddings with pre-trained CoVE vectors.

    Row order follows the ids in item_text_emb.parquet; image/meta/CoVE
    matrices are realigned to that order (missing items become zero rows via
    align_by_ids).

    Returns:
        (item_ids, fused matrix V) with one row per text-file item id.

    Raises:
        ValueError: for an unknown *fusion* mode.
    """
    # base item ids & text vectors (your standard files)
    I_text_ids, Vt = read_item_parquet(proc_dir / "item_text_emb.parquet")
    item_ids = I_text_ids  # master order

    # align optional parts (absent files simply leave Vi/Vm as None)
    Vi = Vm = None
    if (proc_dir / "item_image_emb.parquet").exists():
        I_img_ids, Vi_raw = read_item_parquet(proc_dir / "item_image_emb.parquet")
        Vi = align_by_ids(item_ids, I_img_ids, Vi_raw)
    if (proc_dir / "item_meta_emb.parquet").exists():
        I_met_ids, Vm_raw = read_item_parquet(proc_dir / "item_meta_emb.parquet")
        Vm = align_by_ids(item_ids, I_met_ids, Vm_raw)

    # CoVE item vectors (already trained elsewhere)
    C_ids, Vc_raw = read_item_parquet(cove_fp)
    Vc = align_by_ids(item_ids, C_ids, Vc_raw)

    if fusion == "concat":
        V = concat_fuse((Vt, Vi, Vm, Vc), (w_text, w_image, w_meta, w_cove))
    elif fusion == "weighted":
        # NOTE(review): despite the name, this branch also CONCATENATES
        # (weights are applied as scalars before concat); it never calls
        # weighted_sum — confirm this is the intended semantics.
        # weighted-sum MM first (requires same dim), then concat CoVE (or sum if same dim)
        Vmm = concat_fuse((Vt, Vi, Vm), (w_text, w_image, w_meta))  # safe concat
        V = concat_fuse((Vmm, Vc), (1.0, w_cove))
    else:
        raise ValueError("fusion must be 'concat' or 'weighted'")
    return item_ids, V
|
src/cove/io.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/cove/io.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Tuple, Optional, List, Dict
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
def read_item_parquet(fp: Path, id_col="item_id", vec_col="vector") -> Tuple[np.ndarray, np.ndarray]:
    """Load (ids, vectors) from a parquet file holding one embedding per row.

    The vector column (list/array-valued) is stacked into a dense float32 matrix.
    """
    frame = pd.read_parquet(fp)
    ids = frame[id_col].to_numpy()
    matrix = np.stack(frame[vec_col].to_numpy()).astype(np.float32)
    return ids, matrix
|
| 13 |
+
|
| 14 |
+
def align_by_ids(base_ids: np.ndarray,
                 other_ids: np.ndarray,
                 other_vecs: np.ndarray,
                 dim: Optional[int] = None) -> np.ndarray:
    """Re-order *other_vecs* to follow *base_ids*; ids missing there become zero rows.

    Ids are compared as strings on both sides. When *dim* is not given it is
    inferred from the first available vector (0 when there are none).
    """
    lookup: Dict[str, np.ndarray] = {str(key): vec for key, vec in zip(other_ids, other_vecs)}
    if dim is None:
        sample = next(iter(lookup.values()), None)
        dim = 0 if sample is None else len(sample)
    aligned = np.zeros((len(base_ids), dim), dtype=np.float32)
    for row, key in enumerate(base_ids):
        vec = lookup.get(str(key))
        if vec is not None:
            aligned[row] = vec
    return aligned
|
src/data/.ipynb_checkpoints/init-checkpoint.py
ADDED
|
File without changes
|
src/data/.ipynb_checkpoints/loader-checkpoint.py
ADDED
|
File without changes
|
src/data/.ipynb_checkpoints/registry-checkpoint.py
ADDED
|
File without changes
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Keep this light to avoid pulling heavy modules at import time
|
| 2 |
+
__all__ = []
|
src/data/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (166 Bytes). View file
|
|
|
src/data/__pycache__/loader.cpython-311.pyc
ADDED
|
Binary file (1.1 kB). View file
|
|
|
src/data/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (2.67 kB). View file
|
|
|
src/data/loader.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/data/loader.py
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
def load_dataset(dataset: str):
    """Load a processed dataset's user sequences and candidate items.

    Reads data/processed/<dataset>/seq.json (the full {user_id: [item_id, ...]}
    mapping) and candidate_items.json, returning them as a tuple.
    """
    base_path = Path("data/processed") / dataset
    user_seqs = json.loads((base_path / "seq.json").read_text())
    candidate_items = json.loads((base_path / "candidate_items.json").read_text())
    return user_seqs, candidate_items
|
src/data/registry.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# src/data/registry.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict
|
| 6 |
+
|
| 7 |
+
# Canonical path helpers live in utils.paths
|
| 8 |
+
from utils.paths import (
|
| 9 |
+
RAW_DIR,
|
| 10 |
+
PROCESSED_DIR,
|
| 11 |
+
get_dataset_paths as _get_dataset_paths, # returns dict[str, Path]
|
| 12 |
+
get_raw_path,
|
| 13 |
+
get_processed_path,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_paths(dataset: str) -> Dict[str, Path]:
    """
    Return raw and processed directories for a dataset name (as Path objects),
    creating both if they do not exist. The name is lowercased; a None/empty
    dataset maps to the bare raw/processed roots.

    Example:
        d = get_paths("beauty")
        d["raw_dir"]       -> Path(.../data/raw/beauty)
        d["processed_dir"] -> Path(.../data/processed/beauty)
    """
    name = (dataset or "").lower()
    dirs = {"raw_dir": RAW_DIR / name, "processed_dir": PROCESSED_DIR / name}
    for directory in dirs.values():
        directory.mkdir(parents=True, exist_ok=True)
    return dirs
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def raw_file(dataset: str, filename: str) -> Path:
    """Convenience: Path to *filename* inside data/raw/<dataset>/ (dir is created)."""
    raw_dir = get_paths(dataset)["raw_dir"]
    return raw_dir / filename
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def processed_file(dataset: str, filename: str) -> Path:
    """Convenience: Path to *filename* inside data/processed/<dataset>/ (dir is created)."""
    processed_dir = get_paths(dataset)["processed_dir"]
    return processed_dir / filename
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------
|
| 46 |
+
# Compatibility shim used by older code/tests:
|
| 47 |
+
# This now returns Path objects instead of strings.
|
| 48 |
+
# ---------------------------------------------------------------------
|
| 49 |
+
def get_dataset_paths(dataset: str) -> Dict[str, Path]:
    """
    Compatibility shim: delegate straight to utils.paths.get_dataset_paths so
    older code/tests importing from this module keep working.

    Returns absolute paths (as Path objects) for the given dataset:
    {
        "raw": Path(.../data/raw/<dataset>),
        "processed": Path(.../data/processed/<dataset>),
        "cache": Path(.../data/cache/<dataset>),
        "logs": Path(.../logs),
        "meta_features_path": Path(.../meta_features.npy),
        "text_features_path": Path(.../text_features.npy),
        "image_features_path": Path(.../image_features.npy),
        "labels_path": Path(.../labels.json)
    }
    """
    # Pure delegation — any behavior change belongs in utils.paths, not here.
    return _get_dataset_paths(dataset)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
__all__ = [
|
| 67 |
+
"get_paths",
|
| 68 |
+
"raw_file",
|
| 69 |
+
"processed_file",
|
| 70 |
+
"get_dataset_paths", # keep public for tests
|
| 71 |
+
"get_raw_path",
|
| 72 |
+
"get_processed_path",
|
| 73 |
+
]
|
src/models/.ipynb_checkpoints/fusion-checkpoint.py
ADDED
|
File without changes
|