import torch
from transformers import PreTrainedModel, Qwen2_5_VLForConditionalGeneration

# Local custom-code config, assumed to sit next to this file per the
# Hugging Face remote-code layout.
from configuration_qwen2_5_vl import Qwen2_5_VLConfig


class QWenVLChatModel(PreTrainedModel):
    """Chat-oriented wrapper around Qwen2.5-VL."""

    config_class = Qwen2_5_VLConfig
    # Names the attribute that holds the wrapped model (`self.model` below).
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)
        # Qwen2.5-VL is not registered under AutoModelForCausalLM, so the
        # concrete class is instantiated directly from the config.
        self.model = Qwen2_5_VLForConditionalGeneration(config)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def chat(self, tokenizer, query: str, image=None, history=None, **kwargs):
        # Text-only helper: `history` is assumed to be a list of chat-template
        # message dicts. Image inputs would need the Qwen2.5-VL processor
        # (pixel values plus grid metadata), so `image` is not wired up here.
        messages = list(history or []) + [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(self.device)
        kwargs.setdefault("max_new_tokens", 512)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **kwargs)
        # Decode only the newly generated tokens, not the echoed prompt.
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        return response, messages + [{"role": "assistant", "content": response}]
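

# Minimal usage sketch, assuming transformers >= 4.49 with the Qwen2.5-VL
# classes available. "Qwen/Qwen2.5-VL-3B-Instruct" is used only as an
# illustrative tokenizer checkpoint; building the model from a bare config
# yields randomly initialized weights, so real use would load a pretrained
# checkpoint instead.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
    model = QWenVLChatModel(Qwen2_5_VLConfig()).eval()

    response, history = model.chat(tokenizer, "Say hello in one sentence.")
    print(response)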