jymmmmm committed on
Commit
24b2ebc
·
verified ·
1 Parent(s): 9d3c0d9

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +102 -0
README.md CHANGED
@@ -18,6 +18,108 @@ MAmmoTH-VL2, the model trained with VisualWebInstruct.
18
  [Paper](https://arxiv.org/abs/2503.10582)|
19
  [Website](https://tiger-ai-lab.github.io/VisualWebInstruct/)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Citation
22
  ```
23
  @article{visualwebinstruct,
 
18
  [Paper](https://arxiv.org/abs/2503.10582)|
19
  [Website](https://tiger-ai-lab.github.io/VisualWebInstruct/)
20
 
21
+ # Example Usage
22
+ To perform inference using MAmmoTH-VL2, you can use the following code snippet:
23
+ ```python
24
+ # pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
25
+
26
+ from llava.model.builder import load_pretrained_model
27
+ from llava.mm_utils import process_images
28
+ from llava.constants import DEFAULT_IMAGE_TOKEN
29
+ from llava.conversation import conv_templates
30
+
31
+ from PIL import Image
32
+ import requests
33
+ import copy
34
+ import torch
35
+
36
+ # Load MAmmoTH-VL2 model
37
+ pretrained = "TIGER-Lab/MAmmoTH-VL2"
38
+ model_name = "llava_qwen"
39
+ device = "cuda:3" # Specify a single GPU
40
+ device_map = {"": device}
41
+
42
+ # Load model
43
+ tokenizer, model, image_processor, max_length = load_pretrained_model(
44
+ pretrained,
45
+ None,
46
+ model_name,
47
+ device_map=device_map,
48
+ multimodal=True
49
+ )
50
+ model.eval()
51
+ model = model.to(device)
52
+
53
+ # Load image
54
+ image_url = "https://raw.githubusercontent.com/jymmmmm/VISUALWEBINSTRUCT/main/image.png"
55
+ image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
56
+ images = [image]
57
+ image_sizes = [[image.size[0], image.size[1]]]
58
+
59
+ # Prepare prompt
60
+ prompt = "In the picture shown below, prove ΔWXY and ΔZWY are similar. Please conclude your answer as Answer: xxx at the end if possible."
61
+
62
+ # Set up conversation template
63
+ try:
64
+ conv_template = "qwen_2_5"
65
+ conv = copy.deepcopy(conv_templates[conv_template])
66
+ except KeyError:
67
+ available_templates = list(conv_templates.keys())
68
+ for template_name in available_templates:
69
+ if 'qwen' in template_name.lower():
70
+ conv_template = template_name
71
+ break
72
+ else:
73
+ conv_template = available_templates[0]
74
+ conv = copy.deepcopy(conv_templates[conv_template])
75
+
76
+ # Add question with image
77
+ question = DEFAULT_IMAGE_TOKEN + "\n" + prompt
78
+ conv.append_message(conv.roles[0], question)
79
+ conv.append_message(conv.roles[1], None)
80
+ prompt_question = conv.get_prompt()
81
+
82
+ # Prepare model inputs
83
+ inputs = tokenizer(
84
+ prompt_question,
85
+ return_tensors="pt",
86
+ padding=True,
87
+ truncation=True,
88
+ max_length=max_length
89
+ )
90
+ input_ids = inputs.input_ids.to(device)
91
+ attention_mask = inputs.attention_mask.to(device)
92
+
93
+ # Process image
94
+ image_tensor = process_images(images, image_processor, model.config)
95
+ if isinstance(image_tensor, list):
96
+ image_tensor = [img.to(dtype=torch.float16, device=device) for img in image_tensor]
97
+ else:
98
+ image_tensor = image_tensor.to(dtype=torch.float16, device=device)
99
+
100
+ # Generate response
101
+ with torch.no_grad():
102
+ outputs = model.generate(
103
+ input_ids,
104
+ attention_mask=attention_mask,
105
+ images=image_tensor,
106
+ image_sizes=image_sizes,
107
+ do_sample=False,
108
+ temperature=0,
109
+ max_new_tokens=512,
110
+ )
111
+
112
+ # Decode response
113
+ input_token_len = input_ids.shape[1]
114
+ response = tokenizer.batch_decode(outputs[:, input_token_len:], skip_special_tokens=True)[0]
115
+ print("Response:", response)
116
+
117
+ ```
118
+
119
+
120
+
121
+
122
+
123
  # Citation
124
  ```
125
  @article{visualwebinstruct,