# Hugging Face Spaces page residue (non-code): this app runs on ZeroGPU.
| # Copyright (c) iMED | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| os.environ['TORCHDYNAMO_DISABLE'] = "1" | |
| import sys | |
| import copy | |
| import re | |
| from argparse import ArgumentParser | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
# Try to import the Hugging Face `spaces` package; when it is unavailable
# (e.g. running locally), install a no-op stand-in so `@spaces.GPU` and
# `@spaces.GPU(duration=...)` decorations keep working.
try:
    import spaces
    HAS_SPACES = True
    print(f"HAS_SPACES: {HAS_SPACES}")
except ImportError:
    HAS_SPACES = False
    print(f"HAS_SPACES: {HAS_SPACES}")

    class spaces:
        # Supports both the bare and the parameterized decorator form.
        def GPU(func=None, **kwargs):
            if not func:
                return lambda f: f
            return func
# Default model checkpoint path (Hugging Face Hub repo id)
DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'
# Default system prompt
# Instructs the model to wrap its reasoning in <think></think> and the final
# reply in <answer></answer>; the UI later rewrites these tags into
# "Reasoning Process" / "Final Answer" headers for display.
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>."
)
# Pan and Scan default settings
# (image pre-processing that splits large/elongated images into crops;
# forwarded to processor.apply_chat_template at inference time)
DEFAULT_PAN_SCAN_MAX_CROPS = 2  # maximum number of crops per image
DEFAULT_PAN_SCAN_MIN_RATIO = 1.5  # aspect-ratio threshold that activates cropping
DEFAULT_PAN_SCAN_MIN_CROP_SIZE = 224  # minimum crop edge length, in pixels
# CSS styles
# Injected via gr.Blocks(css=...); the class names below are referenced by
# the elem_classes arguments of the UI components.
CUSTOM_CSS = """
.container {
    max-width: 1400px;
    margin: 0 auto;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-title {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 10px;
}
.sub-title {
    text-align: center;
    color: #666;
    font-size: 1.2em;
    margin-bottom: 30px;
}
.control-height {
    border-radius: 15px;
    border: 1px solid #e0e0e0;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.custom-button {
    border-radius: 8px;
    font-weight: 500;
    transition: all 0.3s ease;
}
textarea {
    border-radius: 10px !important;
    border: 1px solid #d0d0d0 !important;
    padding: 10px !important;
    font-size: 14px !important;
}
.parameter-section {
    background: #f5f5f5;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
"""
def _get_args():
    """Build the CLI parser and return the parsed arguments.

    Options cover the model checkpoint, CPU-only mode, and Gradio
    share/server settings.
    """
    parser = ArgumentParser()
    parser.add_argument('-c', '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    # Gradio launch options.
    parser.add_argument('--share', action='store_true', default=False)
    parser.add_argument('--inbrowser', action='store_true', default=False)
    parser.add_argument('--server-port', type=int, default=7860)
    parser.add_argument('--server-name', type=str, default='0.0.0.0')
    return parser.parse_args()
# ============================================================================
# GLOBAL MODEL STATE - Load lazily inside GPU-decorated functions
# ============================================================================
# Process-wide cache populated by _get_model_and_processor() on first use.
_model = None      # AutoModelForImageTextToText instance once loaded
_processor = None  # matching AutoProcessor
_device = None     # 'cuda' or 'cpu'
def _get_model_and_processor(checkpoint_path):
    """
    Lazily load and cache the model/processor pair.

    Intended to be called from within a @spaces.GPU decorated function so
    CUDA is only touched in GPU context. Repeated calls return the cached
    objects without reloading.
    """
    global _model, _processor, _device
    # Fast path: already loaded earlier in this process.
    if _model is not None:
        return _model, _processor, _device
    use_gpu = torch.cuda.is_available()
    _device = 'cuda' if use_gpu else 'cpu'
    print(f"{'='*50}")
    print(f"π Loading model: {checkpoint_path}")
    print(f"π± Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
    print(f"{'='*50}")
    model_kwargs = {
        'pretrained_model_name_or_path': checkpoint_path,
        'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
        'low_cpu_mem_usage': True,
        'device_map': 'auto' if use_gpu else None,
    }
    try:
        _model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
        _model.eval()
        if not use_gpu:
            # device_map is None on CPU, so the move is done explicitly.
            _model = _model.to(_device)
    except Exception as e:
        print(f"β οΈ Failed to load model: {e}")
        print("π Falling back to CPU mode with float32...")
        # Retry with the most conservative configuration.
        _model = AutoModelForImageTextToText.from_pretrained(
            pretrained_model_name_or_path=checkpoint_path,
            torch_dtype=torch.float32,
            device_map=None,
            low_cpu_mem_usage=True,
        )
        _model = _model.to('cpu')
        _model.eval()
        _device = 'cpu'
    _processor = AutoProcessor.from_pretrained(checkpoint_path)
    print(f"β Model loaded successfully on {_device}")
    return _model, _processor, _device
def encode_image_pil(image_path):
    """Encode an image to a base64 JPEG string with memory-efficient resizing.

    Accepts a file path (str), a numpy array, or a PIL Image.
    Returns the base64-encoded JPEG as a str, or None on any failure
    (errors are printed, never raised).
    """
    try:
        if isinstance(image_path, str):
            img = Image.open(image_path)
        elif isinstance(image_path, np.ndarray):
            img = Image.fromarray(image_path)
        elif isinstance(image_path, Image.Image):
            # Copy so thumbnail() below does not resize the caller's
            # image in place (the original code mutated it).
            img = image_path.copy()
        else:
            print(f"Unsupported image type: {type(image_path)}")
            return None
        # JPEG has no alpha channel; normalize every non-RGB mode to RGB.
        # (Replaces a redundant not-in/elif pair whose branches both ended
        # in the same convert('RGB') call.)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Downscale (aspect-ratio preserving) to cap memory usage.
        max_size = (768, 768)
        img.thumbnail(max_size, Image.Resampling.LANCZOS)
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG", quality=85)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
| def _parse_text(text): | |
| """Parse text for display formatting""" | |
| if text is None: | |
| return "" | |
| text = str(text) | |
| lines = text.split('\n') | |
| lines = [line for line in lines if line != ''] | |
| count = 0 | |
| for i, line in enumerate(lines): | |
| if "<think>" in line: | |
| line = line.replace("<think>", "**Reasoning Process**:\n") | |
| if "</think>" in line: | |
| line = line.replace("</think>", "") | |
| if "<answer>" in line: | |
| line = line.replace("<answer>", "**Final Answer**:\n") | |
| if "</answer>" in line: | |
| line = line.replace("</answer>", "") | |
| if '```' in line: | |
| count += 1 | |
| items = line.split('`') | |
| if count % 2 == 1: | |
| lines[i] = f'<pre><code class="language-{items[-1]}">' | |
| else: | |
| lines[i] = '<br></code></pre>' | |
| else: | |
| if i > 0: | |
| if count % 2 == 1: | |
| line = line.replace('`', r'\`') | |
| line = line.replace('<', '<') | |
| line = line.replace('>', '>') | |
| line = line.replace(' ', ' ') | |
| line = line.replace('*', '*') | |
| line = line.replace('_', '_') | |
| line = line.replace('-', '-') | |
| line = line.replace('.', '.') | |
| line = line.replace('!', '!') | |
| line = line.replace('(', '(') | |
| line = line.replace(')', ')') | |
| line = line.replace('$', '$') | |
| lines[i] = '<br>' + line | |
| text = ''.join(lines) | |
| return text | |
| def _remove_image_special(text): | |
| """Remove special image tags from text""" | |
| if text is None: | |
| return "" | |
| text = text.replace('<ref>', '').replace('</ref>', '') | |
| return re.sub(r'<box>.*?(</box>|$)', '', text) | |
| def _gc(): | |
| """Garbage collection to free memory""" | |
| import gc | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| def _transform_messages(original_messages, system_prompt): | |
| """Transform messages with custom system prompt""" | |
| transformed_messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] | |
| for message in original_messages: | |
| new_content = [] | |
| for item in message['content']: | |
| if 'image' in item: | |
| new_content.append({'type': 'image', 'image': item['image']}) | |
| elif 'text' in item: | |
| new_content.append({'type': 'text', 'text': item['text']}) | |
| if new_content: | |
| transformed_messages.append({'role': message['role'], 'content': new_content}) | |
| return transformed_messages | |
def normalize_task_history_item(item):
    """Coerce a history entry into {'text', 'images', 'response'} form.

    Accepts a dict (keys filled with defaults), a (query, response) pair
    where query may be text or a list of image paths, or any other scalar
    (treated as plain text with no response).
    """
    if isinstance(item, dict):
        return {
            'text': item.get('text', ''),
            'images': item.get('images', []),
            'response': item.get('response', None),
        }
    if isinstance(item, (list, tuple)) and len(item) >= 2:
        query, response = item[0], item[1]
        if isinstance(query, (list, tuple)):
            # A sequence query is a batch of image paths.
            return {'text': '', 'images': list(query), 'response': response}
        return {'text': str(query) if query else '', 'images': [], 'response': response}
    return {'text': str(item) if item else '', 'images': [], 'response': None}
def _launch_demo(args):
    """Build and launch the Gradio demo interface.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options (checkpoint path, share/server settings).
    """

    def call_local_model(messages, system_prompt, temperature, top_p, max_tokens,
                         do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
        """Call the local model with streaming response - loads model lazily.

        Yields (display_text, raw_text) pairs: display_text has the
        think/answer tags rewritten for the UI, raw_text is untouched.
        """
        model, processor, device = _get_model_and_processor(args.checkpoint_path)
        messages = _transform_messages(messages, system_prompt)
        # Apply pan and scan settings dynamically at inference time
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            do_pan_and_scan=do_pan_and_scan,
            pan_and_scan_max_num_crops=pan_scan_max_crops,
            pan_and_scan_min_ratio_to_activate=pan_scan_min_ratio,
            pan_and_scan_min_crop_size=DEFAULT_PAN_SCAN_MIN_CROP_SIZE,
        )
        inputs = inputs.to(device)
        tokenizer = processor.tokenizer
        # Generous timeout: the first token can take a long time on CPU or
        # during a cold model load.
        streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {
            'max_new_tokens': max_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": 20,
            'streamer': streamer,
            **inputs
        }
        with torch.inference_mode():
            # NOTE(review): generate() runs on a worker thread; inference_mode
            # is thread-local, so this context presumably only covers the main
            # thread's work — confirm intended scope.
            thread = Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()
            generated_text = ''
            for new_text in streamer:
                generated_text += new_text
                # Pretty-print tags for the live view, keep raw text for history.
                display_text = generated_text
                if "<think>" in display_text:
                    display_text = display_text.replace("<think>", "**Reasoning Process**:\n")
                if "</think>" in display_text:
                    display_text = display_text.replace("</think>", "\n")
                if "<answer>" in display_text:
                    display_text = display_text.replace("<answer>", "**Final Answer**:\n")
                if "</answer>" in display_text:
                    display_text = display_text.replace("</answer>", "")
                yield display_text, generated_text

    def predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
                do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
        """Stream the assistant reply for the latest user turn, yielding
        updated chatbot state after every token chunk."""
        if not _chatbot or not task_history:
            yield _chatbot
            return
        chat_query = _chatbot[-1][0]
        last_item = normalize_task_history_item(task_history[-1])
        if not chat_query and not last_item['text'] and not last_item['images']:
            # Empty turn: drop it from both the display and the history.
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return
        print(f'User query: {last_item}')
        history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
        full_response_raw = ''
        messages = []
        # Limit history to last 3 turns to save memory
        history_cp = history_cp[-3:]
        for i, item in enumerate(history_cp):
            content = []
            # Process all images without limitation
            if item['images']:
                for img_path in item['images']:
                    if img_path:
                        encoded_img = encode_image_pil(img_path)
                        if encoded_img:
                            content.append({'image': encoded_img})
            if item['text']:
                content.append({'text': str(item['text'])})
            if item['response'] is None:
                # Pending turn (no model answer yet): user message only.
                if content:
                    messages.append({'role': 'user', 'content': content})
            else:
                # Completed turn: user message followed by the assistant reply.
                if content:
                    messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})
        try:
            for response_display, response_raw in call_local_model(
                messages, system_prompt, temperature, top_p, max_tokens,
                do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio
            ):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response_display)))
                yield _chatbot
                full_response_raw = response_raw
            # Persist the raw (untagged) response into the history entry.
            task_history[-1]['response'] = full_response_raw
            print(f'Assistant: {full_response_raw}')
        except Exception as e:
            print(f"Error during generation: {e}")
            import traceback
            traceback.print_exc()
            error_msg = f"Error: {str(e)}"
            _chatbot[-1] = (_parse_text(chat_query), error_msg)
            task_history[-1]['response'] = error_msg
            yield _chatbot

    def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
                   do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
        """Discard the last assistant reply and re-run predict() on the
        same user turn."""
        if not task_history or not _chatbot:
            yield _chatbot
            return
        last_item = normalize_task_history_item(task_history[-1])
        if last_item['response'] is None:
            # Nothing has been generated yet; nothing to redo.
            yield _chatbot
            return
        last_item['response'] = None
        task_history[-1] = last_item
        _chatbot.pop(-1)
        # Rebuild the user-visible message for the re-queued turn.
        display_message_parts = []
        if last_item['images']:
            display_message_parts.append(f"[Uploaded {len(last_item['images'])} image(s)]")
        if last_item['text']:
            display_message_parts.append(last_item['text'])
        display_message = " ".join(display_message_parts)
        _chatbot.append([_parse_text(display_message), None])
        for updated_chatbot in predict(
            _chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
            do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio
        ):
            yield updated_chatbot

    def add_text_and_files(history, task_history, text, files):
        """Append the typed text and/or uploaded files as a new pending turn.

        Returns updated (history, task_history, cleared_text, cleared_files);
        inputs are returned unchanged when both text and files are empty.
        """
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        has_text = text and text.strip()
        has_files = files and len(files) > 0
        if not has_text and not has_files:
            return history, task_history, text, files
        display_parts, file_paths = [], []
        if has_files:
            for file in files:
                if file and hasattr(file, 'name'):
                    file_paths.append(file.name)
            if file_paths:
                display_parts.append(f"[Uploaded {len(file_paths)} image(s)]")
        if has_text:
            display_parts.append(text)
        display_message = " ".join(display_parts)
        history.append([_parse_text(display_message), None])
        task_history.append({'text': text if has_text else '', 'images': file_paths, 'response': None})
        # Clear the textbox and the file widget after queueing the turn.
        return history, task_history, '', None

    def reset_state():
        """Clear chat display, history state, and file widget; free memory."""
        _gc()
        return [], [], None

    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
        # Header banner.
        gr.HTML("""
        <div class="container">
            <h1 class="main-title">IntrinSight Assistant</h1>
            <p class="sub-title">
                Powered by IntrinSight-4B Model (ZeroGPU)
            </p>
        </div>
        """)
        # Per-session conversation state (list of normalized turn dicts).
        task_history = gr.State([])
        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    label='IntrinSight-4B Chat Interface',
                    elem_classes='control-height',
                    height=600,
                    avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
                )
                with gr.Row():
                    query = gr.Textbox(
                        lines=3,
                        label='π¬ Message Input',
                        placeholder="Enter your question here...",
                        elem_classes="custom-input"
                    )
                with gr.Row():
                    addfile_btn = gr.File(
                        label="πΈ Upload Images",
                        file_count="multiple",
                        file_types=["image"],
                        elem_classes="file-upload-area"
                    )
                with gr.Row():
                    submit_btn = gr.Button('π Send', variant="primary", elem_classes="custom-button")
                    regen_btn = gr.Button('π Regenerate', variant="secondary", elem_classes="custom-button")
                    empty_bin = gr.Button('ποΈ Clear History', variant="stop", elem_classes="custom-button")
            with gr.Column(scale=2):
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### βοΈ System Configuration")
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        value=DEFAULT_SYSTEM_PROMPT,
                        lines=5,
                        placeholder="Enter system prompt here..."
                    )
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### ποΈ Generation Parameters")
                    temperature = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                        label="Temperature",
                        info="Higher values make output more random"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=1.0, step=0.05,
                        label="Top-p",
                        info="Cumulative probability for token selection"
                    )
                    max_tokens = gr.Slider(
                        minimum=256, maximum=32768, value=8192, step=256,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### πΌοΈ Image Processing (Pan & Scan)")
                    do_pan_and_scan = gr.Checkbox(
                        label="Enable Pan & Scan",
                        value=True,
                        info="Split large images into crops for better detail"
                    )
                    pan_scan_max_crops = gr.Slider(
                        minimum=1, maximum=6, value=DEFAULT_PAN_SCAN_MAX_CROPS, step=1,
                        label="Max Crops",
                        info="More crops = better detail but higher memory usage"
                    )
                    pan_scan_min_ratio = gr.Slider(
                        minimum=1.0, maximum=3.0, value=DEFAULT_PAN_SCAN_MIN_RATIO, step=0.1,
                        label="Min Ratio to Activate",
                        info="Aspect ratio threshold to trigger pan & scan"
                    )
                gr.Markdown("""
                ### π Instructions
                **Usage:**
                - Enter your question and click Send
                - Upload multiple images as needed
                - First request may take longer (model loading)
                **Memory Tips:**
                - Reduce "Max Crops" if you encounter memory errors
                - Disable "Pan & Scan" for lower memory usage
                - More images = higher memory consumption
                ### β οΈ Disclaimer
                Subject to Gemma license agreement.
                """)
        # Wire events: queue the turn, then stream the prediction.
        submit_btn.click(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
             do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
            [chatbot],
            show_progress="full"
        )
        empty_bin.click(reset_state, outputs=[chatbot, task_history, addfile_btn], show_progress=True)
        regen_btn.click(
            regenerate,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
             do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
            [chatbot],
            show_progress="full"
        )
        # Enter key in the textbox mirrors the Send button.
        query.submit(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
             do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
            [chatbot],
            show_progress="full"
        )
    # Small queue bounds concurrent/pending requests on shared hardware.
    demo.queue(max_size=5).launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
        show_error=True
    )
def main():
    """Entry point: parse CLI arguments and launch the Gradio demo."""
    _launch_demo(_get_args())


if __name__ == '__main__':
    main()