gyubin02 committed
Commit fed41a9 · Parent: 1999995
Files changed (4)
  1. .env.example +3 -0
  2. inference.py +72 -0
  3. requirements.txt +2 -0
  4. train.py +42 -18
.env.example CHANGED
@@ -1,6 +1,9 @@
 # Nexon Open API key
 NEXON_API_KEY=
 
+# Hugging Face token (for private/gated models)
+HUGGINGFACE_HUB_TOKEN=
+
 # Optional output locations
 OUTPUT_DIR=data
 DB_PATH=
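
Note: nothing in this commit reads the new variable explicitly. A minimal sketch of how it could be loaded at startup, assuming python-dotenv (a hypothetical addition; it is not in requirements.txt); once the variable is exported, huggingface_hub picks HUGGINGFACE_HUB_TOKEN up from the environment during from_pretrained calls:

    # Hypothetical startup snippet; python-dotenv is an assumed dependency.
    # Exporting the .env values lets huggingface_hub read HUGGINGFACE_HUB_TOKEN
    # automatically when downloading private/gated models.
    from dotenv import load_dotenv

    load_dotenv()  # exports NEXON_API_KEY, HUGGINGFACE_HUB_TOKEN, etc. into os.environ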
inference.py ADDED
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+import torch
+from PIL import Image
+from peft import PeftModel
+from transformers import SiglipModel, SiglipProcessor
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="SigLIP inference with LoRA adapter.")
+    parser.add_argument("--model-id", default="google/siglip-base-patch16-256-multilingual")
+    parser.add_argument("--adapter-path", default="outputs/ko-clip-lora/best_model")
+    parser.add_argument("--image-path", required=True, type=Path)
+    parser.add_argument(
+        "--candidates",
+        nargs="+",
+        default=[
+            "레인보우 스타",
+            "블랙과 흰색의 별 모양 무기",
+            "하얀 모자",
+            "눈",
+            "관련 없는 이미지",
+        ],
+        help="List of text candidates (Korean recommended).",
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+
+    if not args.image_path.exists():
+        raise FileNotFoundError(f"Image not found: {args.image_path}")
+
+    print("Loading model...")
+    base_model = SiglipModel.from_pretrained(args.model_id)
+    model = PeftModel.from_pretrained(base_model, args.adapter_path)
+    processor = SiglipProcessor.from_pretrained(args.model_id)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+    model.eval()
+
+    image = Image.open(args.image_path).convert("RGB")
+    image_inputs = processor(images=image, return_tensors="pt").to(device)
+    text_inputs = processor(text=args.candidates, return_tensors="pt", padding=True).to(device)
+
+    print(f"\nTarget Image: {args.image_path}")
+    print("-" * 30)
+
+    with torch.no_grad():
+        image_embeds = model.get_image_features(**image_inputs)
+        text_embeds = model.get_text_features(**text_inputs)
+
+        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+
+        logits = image_embeds @ text_embeds.t()
+        logit_scale = model.logit_scale.exp()
+        logits = logits * logit_scale
+        probs = logits.softmax(dim=1)
+
+    for text, prob in zip(args.candidates, probs[0]):
+        print(f"{text}: {prob.item() * 100:.2f}%")
+
+
+if __name__ == "__main__":
+    main()
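
Note: inference.py ranks candidates with a CLIP-style softmax, a reasonable zero-shot classification choice. SigLIP checkpoints also carry a learned logit_bias, and the model's native scoring is an independent sigmoid per (image, text) pair. A sketch of that alternative, reusing the normalized image_embeds and text_embeds from main() and assuming the PeftModel wrapper forwards logit_bias the same way it forwards logit_scale above:

    # Sketch of SigLIP's native pairwise scoring (an alternative, not what the
    # script does): sigmoid over biased logits instead of softmax over candidates.
    import torch

    pair_logits = image_embeds @ text_embeds.t() * model.logit_scale.exp() + model.logit_bias
    pair_probs = torch.sigmoid(pair_logits)  # independent probability per (image, text) pair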
requirements.txt CHANGED
@@ -7,5 +7,7 @@ pillow>=10.0
 pyarrow>=14.0
 torch>=2.1
 transformers>=4.41
+sentencepiece>=0.1.99
+protobuf>=4.21
 peft>=0.11
 scikit-learn>=1.3
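
Note: the two new pins back the model swap in train.py. The multilingual SigLIP tokenizer is SentencePiece-based, so sentencepiece is required, and protobuf is used when converting the vocabulary. A quick sanity check, as a sketch:

    # Sanity-check sketch: loading the processor exercises both new dependencies.
    from transformers import SiglipProcessor

    processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-256-multilingual")
    print(type(processor.tokenizer).__name__)  # expected: SiglipTokenizer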
train.py CHANGED
@@ -14,11 +14,11 @@ from PIL import Image
 from peft import LoraConfig, TaskType, get_peft_model
 from sklearn.model_selection import train_test_split
 from torch.utils.data import DataLoader, Dataset
-from transformers import CLIPModel, CLIPProcessor
+from transformers import SiglipModel, SiglipProcessor
 
 
 class CustomDataset(Dataset):
-    def __init__(self, records: list[dict[str, Any]], processor: CLIPProcessor, max_length: int) -> None:
+    def __init__(self, records: list[dict[str, Any]], processor: SiglipProcessor, max_length: int) -> None:
         self.records = records
         self.image_processor = processor.image_processor
         self.tokenizer = processor.tokenizer
@@ -41,12 +41,20 @@ class CustomDataset(Dataset):
             padding="max_length",
             truncation=True,
             max_length=self.max_length,
+            return_attention_mask=True,
         )
 
+        input_ids = text_inputs["input_ids"][0]
+        if "attention_mask" in text_inputs:
+            attention_mask = text_inputs["attention_mask"][0]
+        else:
+            pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
+            attention_mask = (input_ids != pad_id).long()
+
         return {
             "pixel_values": image_inputs["pixel_values"][0],
-            "input_ids": text_inputs["input_ids"][0],
-            "attention_mask": text_inputs["attention_mask"][0],
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
         }
 
 
@@ -95,9 +103,9 @@ def prepare_model_and_processor(
     lora_r: int,
     lora_alpha: int,
     lora_dropout: float,
-) -> tuple[CLIPModel, CLIPProcessor]:
-    processor = CLIPProcessor.from_pretrained(model_id)
-    base_model = CLIPModel.from_pretrained(model_id)
+) -> tuple[SiglipModel, SiglipProcessor]:
+    processor = SiglipProcessor.from_pretrained(model_id)
+    base_model = SiglipModel.from_pretrained(model_id)
     for param in base_model.parameters():
         param.requires_grad = False
 
@@ -114,9 +122,13 @@ def prepare_model_and_processor(
     return model, processor
 
 
-def clip_contrastive_loss(model: CLIPModel, outputs) -> torch.Tensor:
-    image_embeds = outputs.image_embeds
-    text_embeds = outputs.text_embeds
+def clip_contrastive_loss(
+    model: SiglipModel,
+    image_embeds: torch.Tensor,
+    text_embeds: torch.Tensor,
+) -> torch.Tensor:
+    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
     logit_scale = model.logit_scale.exp().clamp(max=100)
     logits_per_text = logit_scale * text_embeds @ image_embeds.t()
     logits_per_image = logits_per_text.t()
@@ -128,7 +140,7 @@ def clip_contrastive_loss(model: CLIPModel, outputs) -> torch.Tensor:
 
 @torch.no_grad()
 def evaluate(
-    model: CLIPModel,
+    model: SiglipModel,
     data_loader: DataLoader,
     device: torch.device,
     autocast_context,
@@ -139,8 +151,12 @@ def evaluate(
     for batch in data_loader:
         batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
         with autocast_context:
-            outputs = model(**batch)
-            loss = clip_contrastive_loss(model, outputs)
+            image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
+            text_embeds = model.get_text_features(
+                input_ids=batch["input_ids"],
+                attention_mask=batch["attention_mask"],
+            )
+            loss = clip_contrastive_loss(model, image_embeds, text_embeds)
         total_loss += loss.item()
         steps += 1
     return total_loss / max(steps, 1)
@@ -148,8 +164,8 @@ def evaluate(
 
 @torch.no_grad()
 def run_similarity_test(
-    model: CLIPModel,
-    processor: CLIPProcessor,
+    model: SiglipModel,
+    processor: SiglipProcessor,
     sample: dict[str, Any],
     device: torch.device,
     autocast_context,
@@ -205,7 +221,11 @@ def parse_args() -> argparse.Namespace:
         help="Root directory for relative image paths.",
    )
     parser.add_argument("--output-dir", type=Path, default=Path("outputs/ko-clip-lora"))
-    parser.add_argument("--model-id", type=str, default="tech-leader/ko-clip-base-v1-vit-b-32")
+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default="google/siglip-base-patch16-256-multilingual",
+    )
     parser.add_argument("--epochs", type=int, default=10)
     parser.add_argument(
         "--batch-size",
@@ -302,8 +322,12 @@ def main() -> None:
         for step, batch in enumerate(train_loader, start=1):
             batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
             with autocast_context:
-                outputs = model(**batch)
-                loss = clip_contrastive_loss(model, outputs)
+                image_embeds = model.get_image_features(pixel_values=batch["pixel_values"])
+                text_embeds = model.get_text_features(
+                    input_ids=batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                )
+                loss = clip_contrastive_loss(model, image_embeds, text_embeds)
             total_loss += loss.item()
             loss = loss / args.grad_accum_steps
             loss.backward()
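
Note: the clip_contrastive_loss hunks end at logits_per_image, before the part of the function that turns the logits into a loss (that code sits between the hunks and is not shown). Assuming the refactor keeps the symmetric cross-entropy objective the old CLIP-style function implied, the remainder would look roughly like:

    # Hypothetical tail of clip_contrastive_loss (not shown in the diff):
    # symmetric cross-entropy over in-batch pairs, averaged over both directions.
    labels = torch.arange(logits_per_text.size(0), device=logits_per_text.device)
    loss_t = torch.nn.functional.cross_entropy(logits_per_text, labels)
    loss_i = torch.nn.functional.cross_entropy(logits_per_image, labels)
    return (loss_t + loss_i) / 2

Keeping a softmax objective while fine-tuning SigLIP weights is a design choice worth noting: SigLIP was pretrained with a pairwise sigmoid loss, but nothing in this diff suggests the loss function itself was swapped.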