lily-math-rag / test_tokenizer.py
gbrabbit's picture
Auto commit at 07-2025-08 4:43:48
b9ecb65
raw
history blame
5.3 kB
import os
import traceback
from typing import Optional
from transformers import AutoTokenizer
import torch
# Load environment variables from a .env file when python-dotenv is
# available; the script keeps running without it and relies on the
# variables already present in the process environment.
# NOTE(review): the non-ASCII string literals in this script look
# mis-decoded (mojibake); they are kept verbatim here - confirm the file
# encoding against the original source.
try:
    from dotenv import load_dotenv

    load_dotenv()
    print("โœ… .env ํŒŒ์ผ ๋กœ๋“œ๋จ")
except ImportError:
    # python-dotenv is optional: report that it is missing and continue.
    print("โš ๏ธ python-dotenv๊ฐ€ ์„ค์น˜๋˜์ง€ ์•Š์Œ")
# Hugging Face access token; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Environment detection: the run counts as "local" when LOCAL_TEST is set
# or a .env file exists one directory above the current working directory.
IS_LOCAL = 'LOCAL_TEST' in os.environ or os.path.exists('../.env')
print(f"๐Ÿ” ํ™˜๊ฒฝ: {'๋กœ์ปฌ' if IS_LOCAL else '์„œ๋ฒ„'}")

# Choose where the model is loaded from, depending on the environment.
if not IS_LOCAL:
    # Server: use a Hugging Face model id, overridable through MODEL_NAME.
    MODEL_PATH = os.environ.get("MODEL_NAME", "gbrabbit/lily-math-model")
    print(f"๐Ÿ” ์„œ๋ฒ„ ๋ชจ๋ธ: {MODEL_PATH}")
    # Only report whether a token is configured; never print the token.
    print(f"๐Ÿ” ํ† ํฐ: {'โœ… ์„ค์ •๋จ' if HF_TOKEN else 'โŒ ์„ค์ •๋˜์ง€ ์•Š์Œ'}")
else:
    # Local: use the model checkout on disk (path is relative to the
    # current working directory).
    MODEL_PATH = "../lily_llm_core/models/kanana-1.5-v-3b-instruct"
    print(f"๐Ÿ” ๋กœ์ปฌ ๋ชจ๋ธ ๊ฒฝ๋กœ: {MODEL_PATH}")
    print(f"๐Ÿ” ๊ฒฝ๋กœ ์กด์žฌ: {os.path.exists(MODEL_PATH)}")
# ํ† ํฌ๋‚˜์ด์ € ํ…Œ์ŠคํŠธ
print("\n๐Ÿ”ง ํ† ํฌ๋‚˜์ด์ € ํ…Œ์ŠคํŠธ ์‹œ์ž‘...")
try:
print("๐Ÿ“ค ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์ค‘...")
print(f" MODEL_PATH: {MODEL_PATH}")
print(f" IS_LOCAL: {IS_LOCAL}")
print(f" trust_remote_code: True")
print(f" use_fast: False")
if IS_LOCAL:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
token=HF_TOKEN,
trust_remote_code=True,
)
print(f"โœ… ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์™„๋ฃŒ")
print(f" ํƒ€์ž…: {type(tokenizer)}")
print(f" ๊ฐ’: {tokenizer}")
print(f" hasattr('encode'): {hasattr(tokenizer, 'encode')}")
print(f" hasattr('__call__'): {hasattr(tokenizer, '__call__')}")
# ํ† ํฌ๋‚˜์ด์ € ํ…Œ์ŠคํŠธ
test_input = "์•ˆ๋…•ํ•˜์„ธ์š”"
print(f"\n๐Ÿ”ค ํ† ํฌ๋‚˜์ด์ € ํ…Œ์ŠคํŠธ: '{test_input}'")
test_tokens = tokenizer(test_input, return_tensors="pt")
print(f" โœ… ํ† ํฌ๋‚˜์ด์ € ํ˜ธ์ถœ ์„ฑ๊ณต")
print(f" input_ids shape: {test_tokens['input_ids'].shape}")
print(f" attention_mask shape: {test_tokens['attention_mask'].shape}")
# ๋””์ฝ”๋”ฉ ํ…Œ์ŠคํŠธ
decoded = tokenizer.decode(test_tokens['input_ids'][0], skip_special_tokens=True)
print(f" ๋””์ฝ”๋”ฉ ๊ฒฐ๊ณผ: '{decoded}'")
except Exception as e:
print(f"โŒ ํ† ํฌ๋‚˜์ด์ € ํ…Œ์ŠคํŠธ ์‹คํŒจ: {e}")
print(f" ์˜ค๋ฅ˜ ํƒ€์ž…: {type(e).__name__}")
traceback.print_exc()
# ๋ชจ๋ธ ํ…Œ์ŠคํŠธ
print("\n๐Ÿ”ง ๋ชจ๋ธ ํ…Œ์ŠคํŠธ ์‹œ์ž‘...")
try:
print("๐Ÿ“ค ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
from modeling import KananaVForConditionalGeneration
if IS_LOCAL:
model = KananaVForConditionalGeneration.from_pretrained(
MODEL_PATH,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map=None,
low_cpu_mem_usage=True
)
else:
model = KananaVForConditionalGeneration.from_pretrained(
MODEL_PATH,
token=HF_TOKEN,
torch_dtype=torch.float16,
trust_remote_code=True,
device_map=None,
low_cpu_mem_usage=True
)
print(f"โœ… ๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ")
# print(f" ํƒ€์ž…: {type(model)}")
# print(f" ๋””๋ฐ”์ด์Šค: {next(model.parameters()).device}")
# ๋ชจ๋ธ ํ…Œ์ŠคํŠธ
test_input = "์•ˆ๋…•ํ•˜์„ธ์š”"
formatted_prompt = f"<|im_start|>user\n{test_input}<|im_end|>\n<|im_start|>assistant\n"
max_length: Optional[int] = None
inputs = tokenizer(
formatted_prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512
)
print(f"\n๐Ÿค– ๋ชจ๋ธ ์ถ”๋ก  ํ…Œ์ŠคํŠธ: '{test_input}'")
# Kanana์šฉ ์ƒ์„ฑ ์„ค์ •
max_new_tokens = max_length or 100
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=max_new_tokens,
repetition_penalty=1.1,
no_repeat_ngram_size=2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True
)
print(f" โœ… ๋ชจ๋ธ ํ˜ธ์ถœ ์„ฑ๊ณต")
print(f" outputs ํƒ€์ž…: {type(outputs)}")
print(f" outputs shape: {outputs.shape}")
# ๋””์ฝ”๋”ฉ ํ…Œ์ŠคํŠธ
# model.generate()์˜ ์ถœ๋ ฅ์€ ์ „์ฒด ์‹œํ€€์Šค์ด๋ฏ€๋กœ ๋ฐ”๋กœ ๋””์ฝ”๋”ฉํ•ฉ๋‹ˆ๋‹ค.
# outputs[0]์€ ๋ฐฐ์น˜ ์ค‘ ์ฒซ ๋ฒˆ์งธ ๊ฒฐ๊ณผ๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# ์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์‘๋‹ต์—์„œ ์ œ๊ฑฐ (์„ ํƒ์‚ฌํ•ญ)
assistant_response = response.split("<|im_start|>assistant\n")[-1]
print(f" ์ƒ์„ฑ๋œ ์ „์ฒด ํ…์ŠคํŠธ: '{response}'")
print(f" ์–ด์‹œ์Šคํ„ดํŠธ ์‘๋‹ต: '{assistant_response.strip()}'")
except Exception as e:
print(f"โŒ ๋ชจ๋ธ ํ…Œ์ŠคํŠธ ์‹คํŒจ: {e}")
print(f" ์˜ค๋ฅ˜ ํƒ€์ž…: {type(e).__name__}")
traceback.print_exc()
print("\nโœ… ํ…Œ์ŠคํŠธ ์™„๋ฃŒ!")