VOIDER committed on
Commit
1c355ce
·
verified ·
1 Parent(s): 6d24361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -186
app.py CHANGED
@@ -1,200 +1,299 @@
1
- import os
2
- import sys
3
- import subprocess
4
-
5
# --- Make sure llama-cpp-python is importable; install the CPU wheel on demand ---
try:
    from llama_cpp import Llama, LlamaChatCompletionHandler
except ImportError:
    print("Установка llama-cpp-python (CPU)...")
    # Pin to 0.3.16 or newer and pull the wheel from the CPU-only index.
    _pip_command = [
        sys.executable, "-m", "pip", "install",
        "llama-cpp-python>=0.3.16",
        "--extra-index-url", "https://abetlen.github.io/llama-cpp-python/whl/cpu",
    ]
    subprocess.check_call(_pip_command)
    from llama_cpp import Llama, LlamaChatCompletionHandler
else:
    print("Библиотека llama-cpp-python найдена.")

import gradio as gr
from huggingface_hub import hf_hub_download
import base64
import io
import re
from PIL import Image

# Configuration: Hugging Face repository and the GGUF file to download from it.
REPO_ID = "mradermacher/VisualQuality-R1-7B-GGUF"
MODEL_FILENAME = "VisualQuality-R1-7B.Q8_0.gguf"
29
-
30
# === Custom chat handler for Qwen2-VL ===
# The ChatML prompt is assembled by hand instead of relying on the handlers
# bundled with llama-cpp-python.
class CustomQwen2VLHandler(LlamaChatCompletionHandler):
    """Build a Qwen2-VL style ChatML prompt and collect the attached images.

    Each message becomes ``<|im_start|>{role}\\n ... <|im_end|>\\n``. Every
    successfully decoded image contributes one vision placeholder to the
    prompt and one ``bytes`` entry to the image list, so the two always stay
    in step.
    """

    # Qwen2-VL image slot: vision start -> pad -> vision end.
    IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"

    def __init__(self, clip_model_path=None, verbose=False):
        """Store the projector path and the verbosity flag.

        Args:
            clip_model_path: Path kept on the instance; not read by this class.
            verbose: When true, ``__call__`` prints a preview of the prompt.
        """
        self.clip_model_path = clip_model_path
        self.verbose = verbose

    @staticmethod
    def _decode_image_part(part):
        """Return the raw bytes of a base64 data-URL image part, or None."""
        try:
            image_url = part["image_url"]["url"]
            if "base64," not in image_url:
                print("Image skipped: only base64 data URLs are supported.")
                return None
            return base64.b64decode(image_url.split("base64,")[1])
        except (KeyError, TypeError, ValueError) as e:
            # binascii.Error (bad base64) is a ValueError subclass.
            print(f"Ошибка декодирования картинки: {e}")
            return None

    def __call__(self, llama: Llama, messages, functions=None, function_call=None, tools=None, tool_choice=None, **kwargs):
        """Render ``messages`` to a prompt string and gather image bytes.

        Args:
            llama: The model instance (unused here).
            messages: Chat messages; ``content`` is either a string or a list
                of ``{"type": "text"}`` / ``{"type": "image_url"}`` parts.
            functions, function_call, tools, tool_choice, **kwargs: Accepted
                for signature compatibility and ignored.

        Returns:
            A ``(prompt, images)`` tuple.

        NOTE(review): llama-cpp-python appears to return the handler's result
        directly from ``create_chat_completion``, which would expect a
        completion object or chunk iterator rather than this tuple. Confirm
        against the installed version.
        """
        pieces = []
        images = []

        for message in messages:
            pieces.append(f"<|im_start|>{message['role']}\n")
            content = message["content"]

            if isinstance(content, str):
                pieces.append(content)
            elif isinstance(content, list):
                for part in content:
                    if part["type"] == "text":
                        pieces.append(part["text"])
                    elif part["type"] == "image_url":
                        image_bytes = self._decode_image_part(part)
                        # Emit the placeholder only when there is an image to
                        # fill it; otherwise prompt and image list disagree.
                        if image_bytes is not None:
                            pieces.append(self.IMAGE_PLACEHOLDER)
                            images.append(image_bytes)

            pieces.append("<|im_end|>\n")

        # Cue the assistant's turn.
        pieces.append("<|im_start|>assistant\n")
        prompt = "".join(pieces)

        if self.verbose:
            print(f"=== SENDED PROMPT ({len(prompt)} chars) ===")
            print(prompt[:200] + "..." if len(prompt) > 200 else prompt)
            print(f"=== IMAGES: {len(images)} ===")

        return prompt, images
82
-
83
# Lazily created model instance shared by all requests.
llm = None


def load_model():
    """Return the shared Llama instance, downloading and creating it on first use.

    Any exception raised while downloading or constructing the model is
    printed and then re-raised to the caller.
    """
    global llm
    if llm is not None:
        return llm

    print(f"Загрузка модели {MODEL_FILENAME}...")
    try:
        model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

        # The same GGUF file is handed to the custom handler as the clip path.
        handler = CustomQwen2VLHandler(clip_model_path=model_path, verbose=True)

        settings = {
            "model_path": model_path,
            "n_ctx": 8192,        # images are large and need context room
            "n_gpu_layers": 0,    # CPU only
            "verbose": True,
            "chat_handler": handler,
            "n_batch": 512,
            "logits_all": True,
        }
        llm = Llama(**settings)
        print("Модель успешно загружена с CustomQwen2VLHandler!")
    except Exception as e:
        print(f"Ошибка загрузки: {e}")
        raise
    return llm
110
-
111
def process_image(image, max_dim=1024):
    """Encode an image as a base64 JPEG string, downscaled if it is too large.

    The input is converted to RGB first. ``convert`` returns a new image, so
    the later in-place ``thumbnail`` call resizes that copy and the caller's
    image is left untouched. Converting first also means palette images are
    resampled in RGB.

    Args:
        image: A PIL image (anything exposing ``convert``, ``size``,
            ``thumbnail`` and ``save`` in the PIL manner).
        max_dim: Longest allowed side in pixels; larger images are shrunk to
            fit, keeping aspect ratio. Defaults to 1024 to limit CPU memory
            and context use.

    Returns:
        The JPEG bytes (quality 90) encoded as a UTF-8 base64 string.
    """
    rgb = image.convert("RGB")
    if max(rgb.size) > max_dim:
        rgb.thumbnail((max_dim, max_dim), Image.Resampling.LANCZOS)

    buffered = io.BytesIO()
    rgb.save(buffered, format="JPEG", quality=90)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
121
-
122
def evaluate_image(image, progress=gr.Progress()):
    """Stream the model's quality assessment for an uploaded image.

    This is a generator: every yield is a ``(response_text, score_text)``
    pair for the two output components.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.
        progress: Gradio progress tracker.

    Yields:
        The accumulated response with a "thinking" status while streaming,
        then the full response with the extracted score, or an error message
        paired with ``"Error"`` if anything fails.
    """
    if image is None:
        # A generator must yield this message; a plain `return value` would
        # be discarded and the user would see nothing.
        yield "Пожалуйста, загрузите изображение.", ""
        return

    try:
        progress(0.1, desc="Загрузка модели...")
        model = load_model()

        progress(0.2, desc="Обработка...")
        base64_img = process_image(image)
        img_url = f"data:image/jpeg;base64,{base64_img}"

        system_prompt = "You are doing the image quality assessment task."
        user_prompt = (
            "What is your overall rating on the quality of this picture? "
            "The rating should be a float between 1 and 5, rounded to two decimal places, "
            "with 1 representing very poor quality and 5 representing excellent quality. "
            "Please only output the final answer with only one score in <answer> </answer> tags."
        )

        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": img_url}},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]

        full_response = ""
        print("Начинаю генерацию...")

        # Start streaming.
        stream = model.create_chat_completion(
            messages=messages,
            max_tokens=1024,
            temperature=0.6,
            stream=True,
        )

        for chunk in stream:
            if "choices" in chunk:
                delta = chunk["choices"][0]["delta"]
                if "content" in delta and delta["content"]:
                    full_response += delta["content"]
                    yield full_response, "Думаю..."

        # Accept only a well-formed number (e.g. "4", "4.25", ".75"); the old
        # character-class pattern also matched strings such as "..." .
        score_match = re.search(
            r'<answer>\s*(\d*\.\d+|\d+)\s*</answer>', full_response
        )
        final_score = score_match.group(1) if score_match else "Оценка не найдена"

        yield full_response, final_score

    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        err_msg = f"Произошла ошибка: {str(e)}"
        print(err_msg)
        yield err_msg, "Error"
182
-
183
# User interface: image and button on the left, score and reasoning on the right.
with gr.Blocks(title="VisualQuality-R1 (Custom Handler)") as demo:
    gr.Markdown(value="# 👁️ VisualQuality-R1 (Qwen2-VL)")
    gr.Markdown(value="Оценка качества изображений на CPU с кастомным обработчиком.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Изображение")
            run_btn = gr.Button(value="Оценить", variant="primary")

        with gr.Column():
            output_score = gr.Label(label="Оценка")
            output_text = gr.Textbox(label="CoT (Рассуждения)", lines=15)

    # evaluate_image yields (text, score) pairs, hence this output order.
    run_btn.click(
        fn=evaluate_image,
        inputs=[input_img],
        outputs=[output_text, output_score],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer
4
+ from qwen_vl_utils import process_vision_info
5
+ from threading import Thread
6
  import re
7
+ import random
8
+ import spaces
9
+
10
# Constants
MODEL_PATH = "TianheWu/VisualQuality-R1-7B"

# Prompts
PROMPT = "".join([
    "You are doing the image quality assessment task. Here is the question: ",
    "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, ",
    "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality.",
])

QUESTION_TEMPLATE_THINKING = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
QUESTION_TEMPLATE_NO_THINKING = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."

# 8-bit quantization settings
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# The model and processor are loaded once, at import time.
print("Loading model...")
_model_load_options = {
    "quantization_config": quantization_config,
    "device_map": "auto",
    "trust_remote_code": True,
    "torch_dtype": torch.float16,
}
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_PATH, **_model_load_options
)
model.eval()

processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
processor.tokenizer.padding_side = "left"
print("Model loaded successfully!")
43
+
44
+
45
# Compiled once: extract_score runs on every streamed chunk.
_ANSWER_TAG_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
# A decimal with optional integer part (".75", "4.25") or a plain integer.
_SCORE_NUMBER_RE = re.compile(r"\d*\.\d+|\d+")


def extract_score(text):
    """Extract a quality score from model output, clamped to [1.0, 5.0].

    The last ``<answer>...</answer>`` block is used when one exists;
    otherwise the whole text is searched. The first number found in that
    span is the score. Leading-dot decimals such as ``.75`` are read as 0.75
    (and then clamped), not as ``75``.

    Args:
        text: Raw model output.

    Returns:
        The clamped score as a float, or None when ``text`` is not a string
        or contains no number in the searched span.
    """
    if not isinstance(text, str):
        print(f"Error extracting score: expected str, got {type(text).__name__}")
        return None

    tagged_answers = _ANSWER_TAG_RE.findall(text)
    candidate = (tagged_answers[-1] if tagged_answers else text).strip()

    number = _SCORE_NUMBER_RE.search(candidate)
    if number is None:
        return None

    # Keep the score within the 1..5 scale.
    return min(max(float(number.group()), 1.0), 5.0)
60
+
61
+
62
def extract_thinking(text):
    """Return the stripped content of the last closed <think>...</think> block.

    Returns None when the text contains no complete block.
    """
    last_block = None
    for last_block in re.finditer(r'<think>(.*?)</think>', text, re.DOTALL):
        pass  # walk to the final match
    return None if last_block is None else last_block.group(1).strip()
68
+
69
+
70
@spaces.GPU(duration=120)
def score_image_streaming(image, use_thinking=True):
    """Score an image's quality and stream the model output as it is produced.

    This is a generator; every yield is a
    ``(raw_output, thinking, score_markdown)`` triple.

    Args:
        image: PIL image from the UI, or None when nothing was uploaded.
        use_thinking: Use the prompt that asks for a ``<think>`` section
            before the answer, and allow a longer generation.

    Yields:
        Intermediate triples while tokens arrive, then one final triple
        holding either the extracted score or an error message.
    """
    if image is None:
        yield "❌ Please upload an image first.", "", ""
        return

    # Pick the prompt template.
    question_template = (
        QUESTION_TEMPLATE_THINKING if use_thinking else QUESTION_TEMPLATE_NO_THINKING
    )

    # Build the single-turn conversation.
    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image},
                {"type": "text", "text": question_template.format(Question=PROMPT)},
            ],
        }
    ]
    batch_messages = [message]

    # Prepare the model inputs.
    text = [
        processor.apply_chat_template(
            msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
        )
        for msg in batch_messages
    ]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Streaming setup.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=2048 if use_thinking else 256,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        use_cache=True,
    )

    generation_errors = []

    def _generate():
        """Run generation; on failure, record it and unblock the consumer."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # worker-thread boundary
            generation_errors.append(exc)
            # Without an end signal the `for ... in streamer` loop below
            # would wait on its queue forever.
            streamer.end()

    # Generation runs in a worker thread so this generator can relay tokens.
    thread = Thread(target=_generate)
    thread.start()

    generated_text = ""
    current_thinking = ""
    current_score = ""

    for new_text in streamer:
        generated_text += new_text

        # Show reasoning as it is written, not only once </think> arrives.
        thinking = extract_thinking(generated_text)
        if thinking is None and "<think>" in generated_text:
            thinking = generated_text.rsplit("<think>", 1)[1].strip()
        if thinking:
            current_thinking = thinking

        # Wait for a closed answer tag before showing an interim score;
        # otherwise any digit in the reasoning would be shown as the score.
        if "</answer>" in generated_text:
            score = extract_score(generated_text)
            if score is not None:
                current_score = f"⭐ **Quality Score: {score:.2f} / 5.00**"

        yield generated_text, current_thinking, current_score

    thread.join()

    if generation_errors:
        failure = generation_errors[0]
        print(f"Generation failed: {failure}")
        yield generated_text, current_thinking, f"❌ Generation failed: {failure}"
        return

    # Final extraction over the complete output.
    final_score = extract_score(generated_text)
    final_thinking = extract_thinking(generated_text) if use_thinking else ""

    if final_score is not None:
        score_display = f"⭐ **Quality Score: {final_score:.2f} / 5.00**\n\n📊 **For Leaderboard:** `{final_score:.2f}`"
    else:
        score_display = "❌ Could not extract score. Please try again."

    yield generated_text, final_thinking or "", score_display
169
 
170
+
171
def create_interface():
    """Build and return the Gradio Blocks UI for the quality-assessment demo.

    Layout: the left column holds the image upload, thinking-mode checkbox,
    submit button and instructions; the right column holds the score, the
    reasoning and the raw model output. The submit button streams results
    from ``score_image_streaming``.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet queued or launched).
    """

    with gr.Blocks(
        title="VisualQuality-R1: Image Quality Assessment",
        theme=gr.themes.Soft(),
        css="""
        .score-box {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            border-radius: 10px;
            padding: 20px;
            color: white;
            text-align: center;
            font-size: 1.2em;
        }
        .thinking-box {
            background-color: #f0f4f8;
            border-left: 4px solid #667eea;
            padding: 15px;
            border-radius: 5px;
            font-style: italic;
        }
        """
    ) as demo:

        gr.Markdown("""
        # 🎨 VisualQuality-R1: Image Quality Assessment

        **Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank**

        Upload an image to get a quality score (1-5) with detailed reasoning.

        [![Paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2505.14460)
        [![Model](https://img.shields.io/badge/🤗-Model-yellow)](https://huggingface.co/TianheWu/VisualQuality-R1-7B)
        """)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="📷 Upload Image",
                    type="pil",
                    height=400
                )

                thinking_checkbox = gr.Checkbox(
                    label="🧠 Enable Thinking Mode (detailed reasoning)",
                    value=True
                )

                submit_btn = gr.Button(
                    "🔍 Analyze Image Quality",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                ### 📖 Instructions:
                1. Upload an image
                2. Enable/disable thinking mode
                3. Click "Analyze Image Quality"
                4. Wait for the score and reasoning

                ### 📊 Score Scale:
                - **1.0**: Very poor quality
                - **2.0**: Poor quality
                - **3.0**: Fair quality
                - **4.0**: Good quality
                - **5.0**: Excellent quality
                """)

            with gr.Column(scale=1):
                # elem_classes attaches the CSS rules defined above; without
                # it the .score-box / .thinking-box styles are never applied.
                score_output = gr.Markdown(
                    label="Quality Score",
                    value="*Upload an image to see the score*",
                    elem_classes=["score-box"]
                )

                thinking_output = gr.Textbox(
                    label="🧠 Thinking Process",
                    lines=8,
                    max_lines=15,
                    placeholder="Reasoning will appear here when thinking mode is enabled...",
                    interactive=False,
                    elem_classes=["thinking-box"]
                )

                raw_output = gr.Textbox(
                    label="📝 Full Model Output",
                    lines=10,
                    max_lines=20,
                    placeholder="Full model response will appear here...",
                    interactive=False
                )

        # Examples
        gr.Markdown("### 📸 Example Images")
        gr.Examples(
            examples=[
                ["https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/PNG_transparency_demonstration_1.png/300px-PNG_transparency_demonstration_1.png"],
            ],
            inputs=[image_input],
            label="Click to try"
        )

        # Event wiring: output order matches the triples yielded by
        # score_image_streaming (raw text, thinking, score).
        submit_btn.click(
            fn=score_image_streaming,
            inputs=[image_input, thinking_checkbox],
            outputs=[raw_output, thinking_output, score_output],
        )

        gr.Markdown("""
        ---
        ### 📚 Citation
        ```bibtex
        @article{wu2025visualquality,
        title={{VisualQuality-R1}: Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank},
        author={Wu, Tianhe and Zou, Jian and Liang, Jie and Zhang, Lei and Ma, Kede},
        journal={arXiv preprint arXiv:2505.14460},
        year={2025}
        }
        ```
        """)

    return demo
294
 
 
295
 
296
if __name__ == "__main__":
    # Build the UI, cap the request queue at 10 pending jobs, then serve.
    demo = create_interface()
    demo.queue(max_size=10).launch()