#!/usr/bin/env python3
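"""Zero-shot image-to-text matching with SigLIP plus a Korean LoRA adapter.

Usage sketch (script name and sample path are illustrative):

    python inference.py --image-path samples/star.png \
        --adapter-path outputs/ko-clip-lora/best_model
"""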

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from PIL import Image
from peft import PeftModel
from transformers import SiglipModel, SiglipProcessor


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="SigLIP inference with LoRA adapter.")
    parser.add_argument("--model-id", default="google/siglip-base-patch16-256-multilingual")
    parser.add_argument("--adapter-path", default="outputs/ko-clip-lora/best_model")
    parser.add_argument("--image-path", required=True, type=Path)
    parser.add_argument(
        "--candidates",
        nargs="+",
        default=[
            "레인보우 스타",  # "rainbow star"
            "블랙과 흰색의 별 모양 무기",  # "black-and-white star-shaped weapon"
            "흰색 티셔츠",  # "white T-shirt"
            "하늘의 모습",  # "a view of the sky"
            "관련 없는 이미지",  # "an unrelated image"
        ],
        help="List of text candidates (Korean recommended).",
    )
    return parser.parse_args()


def resolve_adapter_path(adapter_path: Path) -> Path:
    """Return the directory that actually holds adapter_config.json.

    Accepts either the adapter directory itself or a parent directory
    containing a best_model/ subdirectory.
    """
    if (adapter_path / "adapter_config.json").exists():
        return adapter_path
    candidate = adapter_path / "best_model"
    if (candidate / "adapter_config.json").exists():
        return candidate
    return adapter_path


def main() -> None:
    args = parse_args()
    if not args.image_path.exists():
        raise FileNotFoundError(f"Image not found: {args.image_path}")

    print("Loading model...")
    base_model = SiglipModel.from_pretrained(args.model_id)
    adapter_path = resolve_adapter_path(Path(args.adapter_path))
    # Attach the trained LoRA adapter weights on top of the base SigLIP model.
    model = PeftModel.from_pretrained(base_model, str(adapter_path))
    processor = SiglipProcessor.from_pretrained(args.model_id)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    model.eval()

    image = Image.open(args.image_path).convert("RGB")
    image_inputs = processor(images=image, return_tensors="pt").to(device)
    # Note: SigLIP checkpoints are trained with padding="max_length";
    # padding=True also runs, but scores can shift slightly.
    text_inputs = processor(text=args.candidates, return_tensors="pt", padding=True).to(device)

    print(f"\nTarget Image: {args.image_path}")
    print("-" * 30)
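
    # Embed the image and every candidate text, L2-normalize both so the dot
    # product below is cosine similarity, then rank the candidates.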
    with torch.no_grad():
        image_embeds = model.get_image_features(**image_inputs)
        text_embeds = model.get_text_features(**text_inputs)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        similarities = image_embeds @ text_embeds.t()

    for text, similarity in zip(args.candidates, similarities[0]):
        sim_value = similarity.item()
        distance = 1.0 - sim_value
        print(f"{text} Similarity: {sim_value:.4f} | Distance: {distance:.4f}")


if __name__ == "__main__":
    main()