gyubin02 committed on
Commit
619dbf0
·
1 Parent(s): 3a167c5
Files changed (4) hide show
  1. indexer.py +4 -1
  2. inference.py +17 -8
  3. main.py +7 -2
  4. requirements.txt +3 -0
indexer.py CHANGED
@@ -133,7 +133,10 @@ def main() -> None:
133
  model.eval()
134
 
135
  client = chromadb.PersistentClient(path=str(args.chroma_path))
136
- collection = client.get_or_create_collection(name=args.collection)
 
 
 
137
 
138
  total_images = len(image_paths)
139
  progress = tqdm(total=total_images, desc="Indexing images", unit="img")
 
133
  model.eval()
134
 
135
  client = chromadb.PersistentClient(path=str(args.chroma_path))
136
+ collection = client.get_or_create_collection(
137
+ name=args.collection,
138
+ metadata={"hnsw:space": "cosine"},
139
+ )
140
 
141
  total_images = len(image_paths)
142
  progress = tqdm(total=total_images, desc="Indexing images", unit="img")
inference.py CHANGED
@@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace:
21
  default=[
22
  "๋ ˆ์ธ๋ณด์šฐ ์Šคํƒ€",
23
  "๋ธ”๋ž™๊ณผ ํฐ์ƒ‰์˜ ๋ณ„ ๋ชจ์–‘ ๋ฌด๊ธฐ",
24
- "핑크색 날개",
25
  "ํŒŒ๋ž€์ƒ‰ ๋ชจ์ž",
26
  "๊ด€๋ จ ์—†๋Š” ์ด๋ฏธ์ง€",
27
  ],
@@ -30,6 +30,15 @@ def parse_args() -> argparse.Namespace:
30
  return parser.parse_args()
31
 
32
 
 
 
 
 
 
 
 
 
 
33
  def main() -> None:
34
  args = parse_args()
35
 
@@ -38,7 +47,8 @@ def main() -> None:
38
 
39
  print("Loading model...")
40
  base_model = SiglipModel.from_pretrained(args.model_id)
41
- model = PeftModel.from_pretrained(base_model, args.adapter_path)
 
42
  processor = SiglipProcessor.from_pretrained(args.model_id)
43
 
44
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -59,13 +69,12 @@ def main() -> None:
59
  image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
60
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
61
 
62
- logits = image_embeds @ text_embeds.t()
63
- logit_scale = model.logit_scale.exp()
64
- logits = logits * logit_scale
65
- probs = logits.softmax(dim=1)
66
 
67
- for text, prob in zip(args.candidates, probs[0]):
68
- print(f"{text}: {prob.item() * 100:.2f}%")
 
 
69
 
70
 
71
  if __name__ == "__main__":
 
21
  default=[
22
  "๋ ˆ์ธ๋ณด์šฐ ์Šคํƒ€",
23
  "๋ธ”๋ž™๊ณผ ํฐ์ƒ‰์˜ ๋ณ„ ๋ชจ์–‘ ๋ฌด๊ธฐ",
24
+ "흰색 티셔츠",
25
  "ํŒŒ๋ž€์ƒ‰ ๋ชจ์ž",
26
  "๊ด€๋ จ ์—†๋Š” ์ด๋ฏธ์ง€",
27
  ],
 
30
  return parser.parse_args()
31
 
32
 
33
+ def resolve_adapter_path(adapter_path: Path) -> Path:
34
+ if (adapter_path / "adapter_config.json").exists():
35
+ return adapter_path
36
+ candidate = adapter_path / "best_model"
37
+ if (candidate / "adapter_config.json").exists():
38
+ return candidate
39
+ return adapter_path
40
+
41
+
42
  def main() -> None:
43
  args = parse_args()
44
 
 
47
 
48
  print("Loading model...")
49
  base_model = SiglipModel.from_pretrained(args.model_id)
50
+ adapter_path = resolve_adapter_path(Path(args.adapter_path))
51
+ model = PeftModel.from_pretrained(base_model, str(adapter_path))
52
  processor = SiglipProcessor.from_pretrained(args.model_id)
53
 
54
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
69
  image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
70
  text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
71
 
72
+ similarities = image_embeds @ text_embeds.t()
 
 
 
73
 
74
+ for text, similarity in zip(args.candidates, similarities[0]):
75
+ sim_value = similarity.item()
76
+ distance = 1.0 - sim_value
77
+ print(f"{text} Similarity: {sim_value:.4f} | Distance: {distance:.4f}")
78
 
79
 
80
  if __name__ == "__main__":
main.py CHANGED
@@ -49,7 +49,10 @@ async def lifespan(app: FastAPI):
49
  model.eval()
50
 
51
  client = chromadb.PersistentClient(path="chroma_db")
52
- collection = client.get_or_create_collection(name="maple_items")
 
 
 
53
 
54
  app.state.device = device
55
  app.state.model = model
@@ -97,7 +100,7 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
97
  results = collection.query(
98
  query_embeddings=[query_embedding],
99
  n_results=payload.k,
100
- include=["distances", "metadatas", "ids"],
101
  )
102
 
103
  ids: List[str] = results.get("ids", [[]])[0]
@@ -110,11 +113,13 @@ def search(payload: SearchRequest) -> Dict[str, Any]:
110
  if metadata:
111
  filepath = metadata.get("filepath", "")
112
  image_url = f"/static/images/{filepath}" if filepath else ""
 
113
  response_items.append(
114
  {
115
  "id": item_id,
116
  "filepath": filepath,
117
  "distance": distance,
 
118
  "image_url": image_url,
119
  }
120
  )
 
49
  model.eval()
50
 
51
  client = chromadb.PersistentClient(path="chroma_db")
52
+ collection = client.get_or_create_collection(
53
+ name="maple_items",
54
+ metadata={"hnsw:space": "cosine"},
55
+ )
56
 
57
  app.state.device = device
58
  app.state.model = model
 
100
  results = collection.query(
101
  query_embeddings=[query_embedding],
102
  n_results=payload.k,
103
+ include=["distances", "metadatas"],
104
  )
105
 
106
  ids: List[str] = results.get("ids", [[]])[0]
 
113
  if metadata:
114
  filepath = metadata.get("filepath", "")
115
  image_url = f"/static/images/{filepath}" if filepath else ""
116
+ similarity = max(0.0, 1.0 - distance) if distance is not None else 0.0
117
  response_items.append(
118
  {
119
  "id": item_id,
120
  "filepath": filepath,
121
  "distance": distance,
122
+ "similarity": similarity,
123
  "image_url": image_url,
124
  }
125
  )
requirements.txt CHANGED
@@ -11,3 +11,6 @@ sentencepiece>=0.1.99
11
  protobuf>=4.21
12
  peft>=0.11
13
  scikit-learn>=1.3
 
 
 
 
11
  protobuf>=4.21
12
  peft>=0.11
13
  scikit-learn>=1.3
14
+ chromadb>=0.5
15
+ fastapi>=0.110
16
+ uvicorn>=0.29