Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, TextIteratorStreamer | |
| from qwen_vl_utils import process_vision_info | |
| from threading import Thread | |
| import re | |
| import spaces | |
# Constants
MODEL_PATH = "TianheWu/VisualQuality-R1-7B"

# Prompts
PROMPT = (
    "You are doing the image quality assessment task. Here is the question: "
    "What is your overall rating on the quality of this picture? The rating should be a float between 1 and 5, "
    "rounded to two decimal places, with 1 representing very poor quality and 5 representing excellent quality."
)

QUESTION_TEMPLATE_THINKING = "{Question} First output the thinking process in <think> </think> tags and then output the final answer with only one score in <answer> </answer> tags."
QUESTION_TEMPLATE_NO_THINKING = "{Question} Please only output the final answer with only one score in <answer> </answer> tags."

# Global model state — populated lazily by load_model() on the first request.
model = None
processor = None
def load_model():
    """Load the model and processor (8-bit quantized) on first use.

    Populates the module-level ``model`` and ``processor`` globals; a
    second call is a no-op.
    """
    global model, processor

    # Already initialized — nothing to do.
    if model is not None:
        return

    print("Loading model...")

    # Processor first; both loads are independent downloads.
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
    # Left padding for batched generation with a decoder-only model.
    processor.tokenizer.padding_side = "left"

    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    model.eval()

    print("Model loaded successfully!")
def extract_score(text):
    """Extract a quality score from model output, clamped to [1.0, 5.0].

    Prefers the content of the last <answer>...</answer> tag; falls back
    to scanning the raw text. Returns None when no number is found.
    """
    try:
        answers = re.findall(r'<answer>(.*?)</answer>', text, re.DOTALL)
        candidate = answers[-1].strip() if answers else text.strip()

        number = re.search(r'\d+(\.\d+)?', candidate)
        if number:
            # Clamp to the valid rating range.
            return min(max(float(number.group()), 1.0), 5.0)
    except Exception as e:
        # Best-effort parser: log and fall through to None.
        print(f"Error extracting score: {e}")
    return None
def extract_thinking(text):
    """Return the content of the last <think>...</think> block, or None."""
    blocks = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    return blocks[-1].strip() if blocks else None
def score_image_streaming(image, use_thinking=True):
    """Assess image quality with streaming output.

    Yields ``(raw_model_output, thinking_text, score_markdown)`` tuples as
    the model generates, ending with a final tuple containing the parsed
    score formatted for display.

    Args:
        image: PIL image from the Gradio Image component, or None.
        use_thinking: when True, prompt for a <think> block before the
            <answer> and allow a larger generation budget (2048 vs 256
            new tokens).
    """
    global model, processor

    # FIX: validate the input BEFORE loading the model — the original
    # called load_model() first and loaded the multi-GB checkpoint even
    # when no image was provided.
    if image is None:
        yield "❌ Please upload an image first.", "", ""
        return

    # Lazy model load on the first real request.
    load_model()

    # Pick the prompt template for this run.
    if use_thinking:
        question_template = QUESTION_TEMPLATE_THINKING
    else:
        question_template = QUESTION_TEMPLATE_NO_THINKING

    # Single-image, single-turn chat message.
    message = [
        {
            "role": "user",
            "content": [
                {'type': 'image', 'image': image},
                {"type": "text", "text": question_template.format(Question=PROMPT)}
            ],
        }
    ]
    batch_messages = [message]

    # Prepare model inputs.
    text = [processor.apply_chat_template(
        msg, tokenize=False, add_generation_prompt=True, add_vision_id=True
    ) for msg in batch_messages]
    image_inputs, video_inputs = process_vision_info(batch_messages)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Streaming setup.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=2048 if use_thinking else 256,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
        use_cache=True,
    )

    # Run generation in a background thread so we can consume the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Stream partial output to the UI.
    generated_text = ""
    current_thinking = ""
    current_score = "*Analyzing...*"

    for new_text in streamer:
        generated_text += new_text

        # Show the thinking block as soon as it is closed.
        thinking = extract_thinking(generated_text)
        if thinking:
            current_thinking = thinking

        # FIX: only parse a score once a closed <answer> tag exists.
        # extract_score() falls back to scanning the raw text, so before
        # the answer arrives it would latch onto stray digits in the
        # thinking process (e.g. "between 1 and 5" -> a bogus "1.00").
        if "</answer>" in generated_text:
            score = extract_score(generated_text)
            if score is not None:
                current_score = f"⭐ **Quality Score: {score:.2f} / 5.00**"

        yield generated_text, current_thinking, current_score

    thread.join()

    # Final parse of the complete output. Uses extract_score's raw-text
    # fallback here, preserving behavior for outputs without <answer> tags.
    final_score = extract_score(generated_text)
    final_thinking = extract_thinking(generated_text) if use_thinking else ""

    if final_score is not None:
        score_display = f"⭐ **Quality Score: {final_score:.2f} / 5.00**\n\n📊 **For Leaderboard:** `{final_score:.2f}`"
    else:
        score_display = "❌ Could not extract score. Please try again."

    yield generated_text, final_thinking or "", score_display
def create_interface():
    """Build and return the Gradio Blocks UI (two-column layout:
    input/controls on the left, streamed results on the right)."""
    # Removed theme from gr.Blocks() — it is now passed to launch()
    with gr.Blocks(
        title="VisualQuality-R1: Image Quality Assessment",
    ) as demo:
        gr.Markdown("""
        # 🎨 VisualQuality-R1: Image Quality Assessment
        **Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank**
        Upload an image to get a quality score (1-5) with detailed reasoning.
        [](https://arxiv.org/abs/2505.14460)
        [](https://huggingface.co/TianheWu/VisualQuality-R1-7B)
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Left column: image input + options + trigger button.
                image_input = gr.Image(
                    label="📷 Upload Image",
                    type="pil",
                    height=400
                )
                thinking_checkbox = gr.Checkbox(
                    label="🧠 Enable Thinking Mode (detailed reasoning)",
                    value=True
                )
                submit_btn = gr.Button(
                    "🔍 Analyze Image Quality",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("""
                ### 📖 Instructions:
                1. Upload an image
                2. Enable/disable thinking mode
                3. Click "Analyze Image Quality"
                4. Wait for the score and reasoning
                ### 📊 Score Scale:
                - **1.0**: Very poor quality
                - **2.0**: Poor quality
                - **3.0**: Fair quality
                - **4.0**: Good quality
                - **5.0**: Excellent quality
                """)
            with gr.Column(scale=1):
                # Right column: streamed score, reasoning, and raw output.
                score_output = gr.Markdown(
                    label="Quality Score",
                    value="*Upload an image to see the score*"
                )
                thinking_output = gr.Textbox(
                    label="🧠 Thinking Process",
                    lines=8,
                    max_lines=15,
                    placeholder="Reasoning will appear here when thinking mode is enabled...",
                    interactive=False
                )
                raw_output = gr.Textbox(
                    label="📝 Full Model Output",
                    lines=10,
                    max_lines=20,
                    placeholder="Full model response will appear here...",
                    interactive=False
                )
        # Event wiring: the generator streams into all three outputs.
        submit_btn.click(
            fn=score_image_streaming,
            inputs=[image_input, thinking_checkbox],
            outputs=[raw_output, thinking_output, score_output],
        )
        gr.Markdown("""
        ---
        ### 📚 Citation
        ```bibtex
        @article{wu2025visualquality,
        title={{VisualQuality-R1}: Reasoning-Induced Image Quality Assessment via Reinforcement Learning to Rank},
        author={Wu, Tianhe and Zou, Jian and Liang, Jie and Zhang, Lei and Ma, Kede},
        journal={arXiv preprint arXiv:2505.14460},
        year={2025}
        }
        ```
        """)
    return demo
if __name__ == "__main__":
    demo = create_interface()
    # Bound the request queue so concurrent users don't pile up unbounded.
    demo.queue(max_size=10)
    # Parameters added for Gradio 6.0
    demo.launch(
        ssr_mode=False,  # Disable SSR for stability
        show_error=True,
    )