TianheWu committed on
Commit
13b207d
·
verified ·
1 Parent(s): 6f54b0a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -2
README.md CHANGED
@@ -11,5 +11,100 @@ tags:
11
  - Reasoning-Induced
12
  ---
13
 
14
- ## ImageQuality-R1-v1
15
- This is a demo version of ImageQuality-R1. Our model is trained on the combination of KADID-10K, TID2013, and KONIQ-10K.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  - Reasoning-Induced
12
  ---
13
 
14
+ # ImageQuality-R1-v1
15
+ This is a demo version of ImageQuality-R1, which is trained on a combination of the KADID-10K, TID2013, and KONIQ-10K datasets.\
16
+ The base model of ImageQuality-R1 is Qwen2.5-VL-7B-Instruct.
17
+
18
+ ## Quick Start
19
+ ```
20
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
21
+ from qwen_vl_utils import process_vision_info
22
+
23
+ import json
24
+ import numpy as np
25
+ import torch
26
+ import random
27
+ import re
28
+ import os
29
+
30
+
31
def _parse_model_output(output_text):
    """Extract the reasoning text and numeric score from one model reply.

    Args:
        output_text: Decoded model reply, expected to contain
            ``<think>...</think>`` followed by ``<answer>...</answer>``.

    Returns:
        A ``(reasoning, score)`` tuple. ``reasoning`` is the stripped content
        of the last ``<think>`` block, or ``""`` when no closed block exists.
        ``score`` is the first number found in the last ``<answer>`` block.

    Raises:
        ValueError: If there is no ``<answer>`` block or it holds no number.
    """
    think_blocks = re.findall(r'<think>(.*?)</think>', output_text, re.DOTALL)
    # Generation is length-limited, so the reasoning may be cut off before
    # </think>; treat that as "no reasoning" rather than crashing.
    reasoning = think_blocks[-1].strip() if think_blocks else ""

    answer_blocks = re.findall(r'<answer>(.*?)</answer>', output_text, re.DOTALL)
    if not answer_blocks:
        raise ValueError(
            f"Model output contains no <answer>...</answer> block: {output_text!r}"
        )

    model_answer = answer_blocks[-1].strip()
    number = re.search(r'\d+(\.\d+)?', model_answer)
    if number is None:
        raise ValueError(f"No numeric score found in answer: {model_answer!r}")

    return reasoning, float(number.group())


def score_image(model_path, image_path, device=None):
    """Rate the quality of a single image with an ImageQuality-R1 checkpoint.

    The model and processor are loaded from ``model_path`` on every call, the
    model is prompted for a quality rating, and its reply is parsed.

    Args:
        model_path: Path or hub identifier passed to ``from_pretrained`` for
            both the model and its processor.
        image_path: Image location handed to the vision pre-processing as the
            ``image`` entry of the chat message.
        device: Device used for ``device_map`` and for the model inputs.
            Defaults to ``cuda:0`` when CUDA is available, otherwise ``cpu``.

    Returns:
        A ``(reasoning, score)`` tuple: the text inside the model's
        ``<think>`` tags and the float parsed from its ``<answer>`` tags.
        The prompt asks for a value between 1 and 5, but the parsed number is
        not range-checked.

    Raises:
        ValueError: If the reply has no parsable ``<answer>`` score.
    """
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): "flash_attention_2" presumably needs the flash-attn
    # package and a CUDA device -- confirm before running on CPU.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    # Load the processor from the same checkpoint as the model; previously
    # this read a module-level constant and ignored the `model_path` argument.
    processor = AutoProcessor.from_pretrained(model_path)
    processor.tokenizer.padding_side = "left"

    PROMPT = (
        "You are doing the image quality assessment task. Here is the question: "
        "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
        "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
    )
    QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."

    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image_path},
                {"type": "text", "text": QUESTION_TEMPLATE.format(Question=PROMPT)},
            ],
        }
    ]
    batch_messages = [message]

    # Preparation for inference
    text = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        for msg in batch_messages
    ]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    # Inference: generation of the output
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=True)
    # Drop the prompt tokens so that only the newly generated reply is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    batch_output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return _parse_model_output(batch_output_text[0])
94
+
95
+
96
# Seed Python's RNG and PyTorch's RNG. Sampling in `generate(do_sample=True)`
# draws from PyTorch's generator, which `random.seed` alone does not affect.
random.seed(42)
torch.manual_seed(42)

# Inference device: the first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

### Modify here
MODEL_PATH = ""   # checkpoint directory or hub identifier of the model
image_path = ""   # image to be rated

reasoning, score = score_image(
    model_path=MODEL_PATH,
    image_path=image_path
)

print(reasoning)
print(score)
110
+ ```