yujingfeng commited on
Commit
e105c1a
·
verified ·
1 Parent(s): 8ccb0bf

Upload modeling_qwen2_5_vl.py

Browse files
Files changed (1) hide show
  1. modeling_qwen2_5_vl.py +16 -3
modeling_qwen2_5_vl.py CHANGED
@@ -1,9 +1,22 @@
1
- from transformers import AutoModelForCausalLM
 
 
2
  import torch
3
 
4
- class QWenVLChatModel(AutoModelForCausalLM):
 
 
 
 
 
 
 
 
 
 
 
5
  def chat(self, tokenizer, query: str, image=None, history=None, **kwargs):
6
  inputs = tokenizer(query, return_tensors="pt").to(self.device)
7
  with torch.no_grad():
8
- outputs = self.generate(**inputs, max_new_tokens=512)
9
  return tokenizer.decode(outputs[0], skip_special_tokens=True), history
 
1
+ from transformers import PreTrainedModel
2
+ from transformers import AutoModelForCausalLM # used internally to instantiate the actual model architecture
3
+ from configuration_qwen2_5_vl import Qwen2_5_VLConfig # your custom config class
4
  import torch
5
 
6
class QWenVLChatModel(PreTrainedModel):
    """Chat wrapper around a causal-LM backbone described by Qwen2_5_VLConfig.

    Subclasses PreTrainedModel (AutoModelForCausalLM is a factory and cannot
    be subclassed) and delegates the concrete architecture to
    AutoModelForCausalLM.from_config, so from_pretrained/save_pretrained
    machinery works through the standard PreTrainedModel path.
    """

    config_class = Qwen2_5_VLConfig
    base_model_prefix = "qwen2_5_vl"

    def __init__(self, config):
        super().__init__(config)
        # Build the concrete model architecture from the config; weights are
        # filled in later by the PreTrainedModel loading machinery.
        self.model = AutoModelForCausalLM.from_config(config)

    def forward(self, *args, **kwargs):
        # Pure delegation to the wrapped backbone.
        return self.model(*args, **kwargs)

    def chat(self, tokenizer, query: str, image=None, history=None, **kwargs):
        """Run one chat turn and return ``(response_text, updated_history)``.

        Args:
            tokenizer: tokenizer whose ``__call__`` returns tensors
                (``return_tensors="pt"``) and that provides ``decode``.
            query: user prompt for this turn.
            image: currently unused — this path is text-only.
                TODO(review): wire the vision tower before accepting images.
            history: optional list of prior ``(query, response)`` turns; the
                input list is not mutated — a copy with this turn appended is
                returned.
            **kwargs: forwarded to ``generate`` (e.g. ``temperature``,
                ``do_sample``); previously these were silently dropped.

        Returns:
            Tuple of the decoded response (prompt tokens stripped) and the
            updated history list.
        """
        history = [] if history is None else list(history)
        inputs = tokenizer(query, return_tensors="pt").to(self.device)
        # Let callers override decoding settings; keep the old default.
        gen_kwargs = {"max_new_tokens": 512, **kwargs}
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
        # Decoder-only generate() returns prompt + completion; slice off the
        # echoed prompt so only the model's reply is returned.
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        history.append((query, response))
        return response, history