yuxianglai117 commited on
Commit
7b0bae8
·
verified ·
1 Parent(s): 88d5eeb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -0
README.md CHANGED
@@ -66,6 +66,63 @@ processor = AutoProcessor.from_pretrained(MODEL_PATH)
66
  ]
67
  ```
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  ## Acknowledgements
70
 
71
  We thank the authors of **OmniMedVQA** and **R1-V** for their open-source contributions.
 
66
  ]
67
  ```
68
 
69
+ ### Inference
70
+ ```python
71
+ with open(PROMPT_PATH, "r", encoding="utf-8") as f:
72
+     data = json.load(f)
73
+
74
+ QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."
75
+
76
+ messages = []
77
+
78
+ for i in data:
79
+     message = [{
80
+         "role": "user",
81
+         "content": [
82
+             {
83
+                 "type": "image",
84
+                 "image": f"file://{i['image']}"
85
+             },
86
+             {
87
+                 "type": "text",
88
+                 "text": QUESTION_TEMPLATE.format(Question=i['problem'])
89
+             }
90
+         ]
91
+     }]
92
+     messages.append(message)
93
+
94
+
95
+ for i in tqdm(range(0, len(messages), BSZ)):
96
+     batch_messages = messages[i:i + BSZ]
97
+
98
+     # Preparation for inference
99
+     text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
100
+
101
+     image_inputs, video_inputs = process_vision_info(batch_messages)
102
+     inputs = processor(
103
+         text=text,
104
+         images=image_inputs,
105
+         videos=video_inputs,
106
+         padding=True,
107
+         return_tensors="pt",
108
+     )
109
+     inputs = inputs.to("cuda")
110
+
111
+     # Inference: Generation of the output
112
+     generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
113
+
114
+     generated_ids_trimmed = [
115
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
116
+     ]
117
+     batch_output_text = processor.batch_decode(
118
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
119
+     )
120
+
121
+     all_outputs.extend(batch_output_text)
122
+     print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
123
+
124
+ ```
125
+
126
  ## Acknowledgements
127
 
128
  We thank the authors of **OmniMedVQA** and **R1-V** for their open-source contributions.