# OmniSVG-3B -- decoder.py
# Defines the SketchDecoder wrapper around a Qwen2.5-VL backbone,
# with configuration loaded from config.yaml.
import torch.nn as nn
import torch
import yaml
import os
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
def load_config(config_path=None):
    """Load and parse ``config.yaml`` into a dict.

    When *config_path* is omitted, search (in order) the directory
    containing this module, its parent directory, and the current
    working directory for a ``config.yaml`` file.

    Raises:
        FileNotFoundError: if no candidate config file exists.
    """
    if config_path is None:
        here = os.path.dirname(os.path.abspath(__file__))
        candidates = (
            os.path.join(here, "config.yaml"),
            os.path.join(here, "..", "config.yaml"),
            "config.yaml",
        )
        # First existing candidate wins; None if the search comes up empty.
        config_path = next((p for p in candidates if os.path.exists(p)), None)
        if config_path is None:
            raise FileNotFoundError("config.yaml not found")
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
class SketchDecoder(nn.Module):
    """Autoregressive generative model.

    Wraps a Qwen2.5-VL conditional-generation backbone whose special
    token ids and vocabulary size come from config.yaml.
    """

    def __init__(self, config_path=None, model_path=None, **kwargs):
        """Build the decoder.

        Args:
            config_path: optional explicit path to config.yaml; falls back
                to the search performed by ``load_config``.
            model_path: optional HF model id / local path; defaults to the
                ``huggingface.qwen_model`` entry of the config.
        """
        super().__init__()

        cfg = load_config(config_path)
        model_cfg = cfg.get('model', {})
        hf_cfg = cfg.get('huggingface', {})

        # Special-token ids are mandatory config entries (KeyError if absent).
        self.bos_token_id = model_cfg['bos_token_id']
        self.eos_token_id = model_cfg['eos_token_id']
        self.pad_token_id = model_cfg['pad_token_id']

        # Default vocab: just large enough to index every special token.
        fallback_vocab = 1 + max(
            self.bos_token_id, self.eos_token_id, self.pad_token_id
        )
        self.vocab_size = model_cfg.get('vocab_size', fallback_vocab)

        if model_path is None:
            model_path = hf_cfg['qwen_model']

        backbone_cfg = AutoConfig.from_pretrained(
            model_path,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
        )
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            config=backbone_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            device_map="auto",
            # Checkpoint tensors whose shapes disagree with backbone_cfg
            # (e.g. the resized vocab) are re-initialized instead of erroring.
            ignore_mismatched_sizes=True,
        )
        # Keep the embedding table in sync with the configured vocabulary.
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        # Training/inference pass intentionally withheld from this release.
        raise NotImplementedError("Forward pass not included in open-source version")