import os

import torch
import torch.nn as nn
import yaml
from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig


def load_config(config_path=None):
    """Load configuration from config.yaml."""
    if config_path is None:
        # Search a few likely locations relative to this file.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        possible_paths = [
            os.path.join(current_dir, "config.yaml"),
            os.path.join(current_dir, "..", "config.yaml"),
            "config.yaml"
        ]
        for path in possible_paths:
            if os.path.exists(path):
                config_path = path
                break
        if config_path is None:
            raise FileNotFoundError("config.yaml not found")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config
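
# For reference, a sketch of the config.yaml shape this loader expects. The
# keys mirror the lookups in SketchDecoder.__init__ below, but the ids and
# checkpoint name shown here are illustrative placeholders, not the project's
# actual values:
#
#   model:
#     bos_token_id: 151643
#     eos_token_id: 151644
#     pad_token_id: 151645
#     vocab_size: 152000          # optional; defaults to max(id) + 1
#   huggingface:
#     qwen_model: Qwen/Qwen2.5-VL-3B-Instruct   # any Qwen2.5-VL checkpoint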


class SketchDecoder(nn.Module):
    """Autoregressive sketch-token decoder built on a Qwen2.5-VL backbone."""

    def __init__(self, config_path=None, model_path=None, **kwargs):
        super().__init__()
        config_data = load_config(config_path)
        model_config = config_data.get('model', {})
        huggingface_config = config_data.get('huggingface', {})

        # Special-token ids come from the config; if vocab_size is absent,
        # default to the smallest vocabulary that can hold all three ids.
        self.bos_token_id = model_config['bos_token_id']
        self.eos_token_id = model_config['eos_token_id']
        self.pad_token_id = model_config['pad_token_id']
        self.vocab_size = model_config.get(
            'vocab_size',
            max(self.bos_token_id, self.eos_token_id, self.pad_token_id) + 1
        )

        if model_path is None:
            model_path = huggingface_config['qwen_model']

        # Override the checkpoint's vocabulary and special-token ids;
        # ignore_mismatched_sizes tolerates the resulting shape differences,
        # and the embeddings are resized to match below.
        config = AutoConfig.from_pretrained(
            model_path,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id
        )
        self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path,
            config=config,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            device_map="auto",
            ignore_mismatched_sizes=True
        )
        self.transformer.resize_token_embeddings(self.vocab_size)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("Forward pass not included in open-source version")
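

# Minimal usage sketch (assumptions: a config.yaml resolvable by load_config
# and access to the referenced Qwen checkpoint; all names below come from
# this file). Since forward() is stubbed out in this open-source release,
# sampling goes through the wrapped transformer's generate() directly:
if __name__ == "__main__":
    decoder = SketchDecoder()
    device = next(decoder.transformer.parameters()).device
    prompt = torch.tensor([[decoder.bos_token_id]], device=device)
    tokens = decoder.transformer.generate(
        prompt,
        max_new_tokens=32,
        eos_token_id=decoder.eos_token_id,
        pad_token_id=decoder.pad_token_id
    )
    print(tokens)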