Spaces:
Sleeping
- indexer.py +4 -1
- inference.py +17 -8
- main.py +7 -2
- requirements.txt +3 -0
indexer.py
CHANGED
|
@@ -133,7 +133,10 @@ def main() -> None:
|
|
| 133 |
model.eval()
|
| 134 |
|
| 135 |
client = chromadb.PersistentClient(path=str(args.chroma_path))
|
| 136 |
-
collection = client.get_or_create_collection(
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
total_images = len(image_paths)
|
| 139 |
progress = tqdm(total=total_images, desc="Indexing images", unit="img")
|
|
|
|
| 133 |
model.eval()
|
| 134 |
|
| 135 |
client = chromadb.PersistentClient(path=str(args.chroma_path))
|
| 136 |
+
collection = client.get_or_create_collection(
|
| 137 |
+
name=args.collection,
|
| 138 |
+
metadata={"hnsw:space": "cosine"},
|
| 139 |
+
)
|
| 140 |
|
| 141 |
total_images = len(image_paths)
|
| 142 |
progress = tqdm(total=total_images, desc="Indexing images", unit="img")
|
inference.py
CHANGED
|
@@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace:
|
|
| 21 |
default=[
|
| 22 |
"๋ ์ธ๋ณด์ฐ ์คํ",
|
| 23 |
"๋ธ๋๊ณผ ํฐ์์ ๋ณ ๋ชจ์ ๋ฌด๊ธฐ",
|
| 24 |
-
"
|
| 25 |
"ํ๋์ ๋ชจ์",
|
| 26 |
"๊ด๋ จ ์๋ ์ด๋ฏธ์ง",
|
| 27 |
],
|
|
@@ -30,6 +30,15 @@ def parse_args() -> argparse.Namespace:
|
|
| 30 |
return parser.parse_args()
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def main() -> None:
|
| 34 |
args = parse_args()
|
| 35 |
|
|
@@ -38,7 +47,8 @@ def main() -> None:
|
|
| 38 |
|
| 39 |
print("Loading model...")
|
| 40 |
base_model = SiglipModel.from_pretrained(args.model_id)
|
| 41 |
-
|
|
|
|
| 42 |
processor = SiglipProcessor.from_pretrained(args.model_id)
|
| 43 |
|
| 44 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -59,13 +69,12 @@ def main() -> None:
|
|
| 59 |
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
| 60 |
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
| 61 |
|
| 62 |
-
|
| 63 |
-
logit_scale = model.logit_scale.exp()
|
| 64 |
-
logits = logits * logit_scale
|
| 65 |
-
probs = logits.softmax(dim=1)
|
| 66 |
|
| 67 |
-
for text,
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
if __name__ == "__main__":
|
|
|
|
| 21 |
default=[
|
| 22 |
"๋ ์ธ๋ณด์ฐ ์คํ",
|
| 23 |
"๋ธ๋๊ณผ ํฐ์์ ๋ณ ๋ชจ์ ๋ฌด๊ธฐ",
|
| 24 |
+
"ํฐ์ ํฐ์
์ธ ",
|
| 25 |
"ํ๋์ ๋ชจ์",
|
| 26 |
"๊ด๋ จ ์๋ ์ด๋ฏธ์ง",
|
| 27 |
],
|
|
|
|
| 30 |
return parser.parse_args()
|
| 31 |
|
| 32 |
|
| 33 |
+
def resolve_adapter_path(adapter_path: Path) -> Path:
|
| 34 |
+
if (adapter_path / "adapter_config.json").exists():
|
| 35 |
+
return adapter_path
|
| 36 |
+
candidate = adapter_path / "best_model"
|
| 37 |
+
if (candidate / "adapter_config.json").exists():
|
| 38 |
+
return candidate
|
| 39 |
+
return adapter_path
|
| 40 |
+
|
| 41 |
+
|
| 42 |
def main() -> None:
|
| 43 |
args = parse_args()
|
| 44 |
|
|
|
|
| 47 |
|
| 48 |
print("Loading model...")
|
| 49 |
base_model = SiglipModel.from_pretrained(args.model_id)
|
| 50 |
+
adapter_path = resolve_adapter_path(Path(args.adapter_path))
|
| 51 |
+
model = PeftModel.from_pretrained(base_model, str(adapter_path))
|
| 52 |
processor = SiglipProcessor.from_pretrained(args.model_id)
|
| 53 |
|
| 54 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 69 |
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
|
| 70 |
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
|
| 71 |
|
| 72 |
+
similarities = image_embeds @ text_embeds.t()
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
for text, similarity in zip(args.candidates, similarities[0]):
|
| 75 |
+
sim_value = similarity.item()
|
| 76 |
+
distance = 1.0 - sim_value
|
| 77 |
+
print(f"{text} Similarity: {sim_value:.4f} | Distance: {distance:.4f}")
|
| 78 |
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
main.py
CHANGED
|
@@ -49,7 +49,10 @@ async def lifespan(app: FastAPI):
|
|
| 49 |
model.eval()
|
| 50 |
|
| 51 |
client = chromadb.PersistentClient(path="chroma_db")
|
| 52 |
-
collection = client.get_or_create_collection(
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
app.state.device = device
|
| 55 |
app.state.model = model
|
|
@@ -97,7 +100,7 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
|
|
| 97 |
results = collection.query(
|
| 98 |
query_embeddings=[query_embedding],
|
| 99 |
n_results=payload.k,
|
| 100 |
-
include=["distances", "metadatas"
|
| 101 |
)
|
| 102 |
|
| 103 |
ids: List[str] = results.get("ids", [[]])[0]
|
|
@@ -110,11 +113,13 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
|
|
| 110 |
if metadata:
|
| 111 |
filepath = metadata.get("filepath", "")
|
| 112 |
image_url = f"/static/images/{filepath}" if filepath else ""
|
|
|
|
| 113 |
response_items.append(
|
| 114 |
{
|
| 115 |
"id": item_id,
|
| 116 |
"filepath": filepath,
|
| 117 |
"distance": distance,
|
|
|
|
| 118 |
"image_url": image_url,
|
| 119 |
}
|
| 120 |
)
|
|
|
|
| 49 |
model.eval()
|
| 50 |
|
| 51 |
client = chromadb.PersistentClient(path="chroma_db")
|
| 52 |
+
collection = client.get_or_create_collection(
|
| 53 |
+
name="maple_items",
|
| 54 |
+
metadata={"hnsw:space": "cosine"},
|
| 55 |
+
)
|
| 56 |
|
| 57 |
app.state.device = device
|
| 58 |
app.state.model = model
|
|
|
|
| 100 |
results = collection.query(
|
| 101 |
query_embeddings=[query_embedding],
|
| 102 |
n_results=payload.k,
|
| 103 |
+
include=["distances", "metadatas"],
|
| 104 |
)
|
| 105 |
|
| 106 |
ids: List[str] = results.get("ids", [[]])[0]
|
|
|
|
| 113 |
if metadata:
|
| 114 |
filepath = metadata.get("filepath", "")
|
| 115 |
image_url = f"/static/images/{filepath}" if filepath else ""
|
| 116 |
+
similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
|
| 117 |
response_items.append(
|
| 118 |
{
|
| 119 |
"id": item_id,
|
| 120 |
"filepath": filepath,
|
| 121 |
"distance": distance,
|
| 122 |
+
"similarity": similarity,
|
| 123 |
"image_url": image_url,
|
| 124 |
}
|
| 125 |
)
|
requirements.txt
CHANGED
|
@@ -11,3 +11,6 @@ sentencepiece>=0.1.99
|
|
| 11 |
protobuf>=4.21
|
| 12 |
peft>=0.11
|
| 13 |
scikit-learn>=1.3
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
protobuf>=4.21
|
| 12 |
peft>=0.11
|
| 13 |
scikit-learn>=1.3
|
| 14 |
+
chromadb>=0.5
|
| 15 |
+
fastapi>=0.110
|
| 16 |
+
uvicorn>=0.29
|