VOIDER's picture
Update app.py
c37360a verified
raw
history blame
9.66 kB
import gradio as gr
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread
import re
import spaces
# Constants
MODEL_PATH = "TianheWu/VisualQuality-R1-7B"
# Prompts
PROMPT = (
    "You are doing the image quality assessment task. Here is the question: "
    "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
    "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
)
# Template that asks the model for a <think> reasoning trace before the <answer> score.
QUESTION_TEMPLATE_THINKING = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
# Template that asks only for the final <answer> score (no reasoning trace).
QUESTION_TEMPLATE_NO_THINKING = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."
# Module-level globals holding the lazily loaded model and processor.
model = None
processor = None
def load_model():
    """Load the VLM and its processor with 8-bit quantization (idempotent).

    Stores the results in the module-level ``model`` / ``processor`` globals;
    a second call is a no-op once the model is loaded.
    """
    global model, processor
    if model is not None:
        # Already loaded — nothing to do.
        return
    print("Loading model...")
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # Left padding so batched generation aligns prompts at the right edge.
    processor.tokenizer.padding_side = "left"
    print("Model loaded successfully!")
def extract_score(text):
    """Pull a numeric quality score out of the model's raw output.

    Prefers the contents of the last <answer>...</answer> tag, falling back to
    the whole text. The first number found is clamped to [1.0, 5.0].
    Returns None when no number is present or parsing fails.
    """
    try:
        answers = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
        candidate = answers[-1].strip() if answers else text.strip()
        number = re.search(r'\d+(\.\d+)?', candidate)
        if number is not None:
            # Clamp into the valid rating range.
            return min(max(float(number.group()), 1.0), 5.0)
    except Exception as e:
        print(f"Error extracting score: {e}")
    return None
def extract_thinking(text):
    """Return the contents of the last <think>...</think> block, or None."""
    found = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    return found[-1].strip() if found else None
@spaces.GPU(duration=180)
def score_image_streaming(image, use_thinking=True):
    """Assess image quality, streaming partial model output as it is generated.

    Args:
        image: PIL image to score (presumably; Gradio supplies type="pil" —
            None yields an error message and returns early).
        use_thinking: when True, prompt the model to emit a <think> reasoning
            trace before the <answer> score.

    Yields:
        (raw_generated_text, thinking_text, score_markdown) tuples, updated
        incrementally for each streamed chunk, plus one final tuple.
    """
    global model, processor
    # Lazily load the model on the first call.
    load_model()
    if image is None:
        yield "❌ Please upload an image first.", "", ""
        return
    # Choose the prompt template based on whether reasoning is requested.
    if use_thinking:
        question_template = QUESTION_TEMPLATE_THINKING
    else:
        question_template = QUESTION_TEMPLATE_NO_THINKING
    # Build the chat message: one image plus the formatted question text.
    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image},
                {"type": "text", "text": question_template.format(Question=PROMPT)}
            ],
        }
    ]
    batch_messages = [message]
    # Prepare model inputs: apply the chat template, then run vision preprocessing.
    text = [processor.apply_chat_template(
        msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
    ) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)
    # Streamer that drops the prompt and special tokens from the decoded output.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        # Reasoning mode needs a much larger token budget than score-only mode.
        max_new_tokens=2048 if use_thinking else 256,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        use_cache=True,
    )
    # Run generation in a background thread so this generator can consume the
    # streamer concurrently.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Stream the output, re-parsing the thinking trace and score on each chunk.
    generated_text = ""
    current_thinking = ""
    current_score = "*Analyzing...*"
    for new_text in streamer:
        generated_text += new_text
        # Extract the reasoning trace, if one is present so far.
        thinking = extract_thinking(generated_text)
        if thinking:
            current_thinking = thinking
        # Extract the score, if one is present so far.
        score = extract_score(generated_text)
        if score is not None:
            current_score = f"⭐ **Quality Score: {score:.2f} / 5.00**"
        yield generated_text, current_thinking, current_score
    thread.join()
    # Final parse once generation has fully completed.
    final_score = extract_score(generated_text)
    final_thinking = extract_thinking(generated_text) if use_thinking else ""
    if final_score is not None:
        score_display = f"⭐ **Quality Score: {final_score:.2f} / 5.00**\n\n📊 **For Leaderboard:** `{final_score:.2f}`"
    else:
        score_display = "❌ Could not extract score. Please try again."
    yield generated_text, final_thinking or "", score_display
def create_interface():
    """Build and return the Gradio Blocks UI for the quality-assessment demo."""
    # NOTE: theme was removed from gr.Blocks() — it is now passed to launch().
    with gr.Blocks(
        title="VisualQuality-R1: Image Quality Assessment",
    ) as demo:
        gr.Markdown("""
# 🎨 VisualQuality-R1: Image Quality Assessment
**Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank**
Upload an image to get a quality score (1-5) with detailed reasoning.
[![Paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2505.14460)
[![Model](https://img.shields.io/badge/🤗-Model-yellow)](https://huggingface.co/TianheWu/VisualQuality-R1-7B)
""")
        with gr.Row():
            # Left column: input image, options, and the submit button.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="📷 Upload Image",
                    type="pil",
                    height=400
                )
                thinking_checkbox = gr.Checkbox(
                    label="🧠 Enable Thinking Mode (detailed reasoning)",
                    value=True
                )
                submit_btn = gr.Button(
                    "🔍 Analyze Image Quality",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("""
### 📖 Instructions:
1. Upload an image
2. Enable/disable thinking mode
3. Click "Analyze Image Quality"
4. Wait for the score and reasoning
### 📊 Score Scale:
- **1.0**: Very poor quality
- **2.0**: Poor quality
- **3.0**: Fair quality
- **4.0**: Good quality
- **5.0**: Excellent quality
""")
            # Right column: streamed score, reasoning trace, and raw output.
            with gr.Column(scale=1):
                score_output = gr.Markdown(
                    label="Quality Score",
                    value="*Upload an image to see the score*"
                )
                thinking_output = gr.Textbox(
                    label="🧠 Thinking Process",
                    lines=8,
                    max_lines=15,
                    placeholder="Reasoning will appear here when thinking mode is enabled...",
                    interactive=False
                )
                raw_output = gr.Textbox(
                    label="📝 Full Model Output",
                    lines=10,
                    max_lines=20,
                    placeholder="Full model response will appear here...",
                    interactive=False
                )
        # Wire the button to the streaming scorer.
        submit_btn.click(
            fn=score_image_streaming,
            inputs=[image_input, thinking_checkbox],
            outputs=[raw_output, thinking_output, score_output],
        )
        gr.Markdown("""
---
### 📚 Citation
```bibtex
@article{wu2025visualquality,
title={{VisualQuality-R1}: Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank},
author={Wu, Tianhe and Zou, Jian and Liang, Jie and Zhang, Lei and Ma, Kede},
journal={arXiv preprint arXiv:2505.14460},
year={2025}
}
```
""")
    return demo
if __name__ == "__main__":
    # Build the UI, cap the request queue, and launch the server.
    app = create_interface()
    app.queue(max_size=10)
    # Parameters added for Gradio 6.0: SSR disabled for stability,
    # and server-side errors surfaced to the client.
    app.launch(
        ssr_mode=False,
        show_error=True,
    )