# Copyright (c) iMED
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
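# Force eager execution: TORCHDYNAMO_DISABLE=1 turns off TorchDynamo/torch.compile
# graph capture process-wide; it is set before torch is imported so it reliably
# takes effect.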
os.environ['TORCHDYNAMO_DISABLE'] = "1"
import sys
import copy
import re
from argparse import ArgumentParser
from threading import Thread
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
import numpy as np
from PIL import Image
import io
import base64
import warnings
warnings.filterwarnings('ignore')
# Try to import `spaces` (available on Hugging Face Spaces); fall back to a
# no-op GPU decorator when running elsewhere.
try:
import spaces
HAS_SPACES = True
print(f"HAS_SPACES: {HAS_SPACES}")
except ImportError:
HAS_SPACES = False
print(f"HAS_SPACES: {HAS_SPACES}")
class spaces:
@staticmethod
def GPU(func=None, **kwargs):
if func:
return func
return lambda f: f
# Default model checkpoint path
DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'
# Default system prompt
DEFAULT_SYSTEM_PROMPT = (
"A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. "
"The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>."
)
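# The <think>/<answer> tags requested by this prompt are rewritten into readable
# headings for display (see _parse_text and the streaming loop in call_local_model).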
# Pan & Scan default settings: splits large or extreme-aspect-ratio images into
# sub-crops so fine detail survives downscaling (Gemma 3-style preprocessing).
DEFAULT_PAN_SCAN_MAX_CROPS = 2
DEFAULT_PAN_SCAN_MIN_RATIO = 1.5
DEFAULT_PAN_SCAN_MIN_CROP_SIZE = 224
# CSS styles
CUSTOM_CSS = """
.container {
max-width: 1400px;
margin: 0 auto;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.main-title {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3em;
font-weight: bold;
text-align: center;
margin-bottom: 10px;
}
.sub-title {
text-align: center;
color: #666;
font-size: 1.2em;
margin-bottom: 30px;
}
.control-height {
border-radius: 15px;
border: 1px solid #e0e0e0;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.custom-button {
border-radius: 8px;
font-weight: 500;
transition: all 0.3s ease;
}
textarea {
border-radius: 10px !important;
border: 1px solid #d0d0d0 !important;
padding: 10px !important;
font-size: 14px !important;
}
.parameter-section {
background: #f5f5f5;
border-radius: 10px;
padding: 15px;
margin-bottom: 15px;
}
"""
def _get_args():
"""Parse command line arguments"""
parser = ArgumentParser()
parser.add_argument('-c', '--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
help='Checkpoint name or path')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--share', action='store_true', default=False)
parser.add_argument('--inbrowser', action='store_true', default=False)
parser.add_argument('--server-port', type=int, default=7860)
parser.add_argument('--server-name', type=str, default='0.0.0.0')
args = parser.parse_args()
return args
# ============================================================================
# GLOBAL MODEL STATE - Load lazily inside GPU-decorated functions
# ============================================================================
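# On ZeroGPU Spaces, CUDA is only attached while a @spaces.GPU-decorated call is
# running, so the model must be loaded on first use rather than at import time.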
_model = None
_processor = None
_device = None
def _get_model_and_processor(checkpoint_path):
"""
Lazy-load model and processor.
This should be called from within a @spaces.GPU decorated function.
"""
global _model, _processor, _device
if _model is not None:
return _model, _processor, _device
print(f"{'='*50}")
print(f"πŸš€ Loading model: {checkpoint_path}")
use_gpu = torch.cuda.is_available()
_device = 'cuda' if use_gpu else 'cpu'
print(f"πŸ“± Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
print(f"{'='*50}")
model_kwargs = {
'pretrained_model_name_or_path': checkpoint_path,
'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
'low_cpu_mem_usage': True,
}
if use_gpu:
model_kwargs['device_map'] = 'auto'
else:
model_kwargs['device_map'] = None
try:
_model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
_model.eval()
if not use_gpu:
_model = _model.to(_device)
except Exception as e:
print(f"⚠️ Failed to load model: {e}")
print("πŸ”„ Falling back to CPU mode with float32...")
model_kwargs = {
'pretrained_model_name_or_path': checkpoint_path,
'torch_dtype': torch.float32,
'device_map': None,
'low_cpu_mem_usage': True
}
_model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
_model = _model.to('cpu')
_model.eval()
_device = 'cpu'
_processor = AutoProcessor.from_pretrained(checkpoint_path)
print(f"βœ… Model loaded successfully on {_device}")
return _model, _processor, _device
def encode_image_pil(image_path):
"""Encode image to base64 using PIL with memory-efficient resizing"""
try:
if isinstance(image_path, str):
img = Image.open(image_path)
elif isinstance(image_path, np.ndarray):
img = Image.fromarray(image_path)
elif isinstance(image_path, Image.Image):
img = image_path
else:
print(f"Unsupported image type: {type(image_path)}")
return None
        if img.mode != 'RGB':
            img = img.convert('RGB')
max_size = (768, 768)
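        # thumbnail() resizes in place, preserves aspect ratio, and never upscales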
img.thumbnail(max_size, Image.Resampling.LANCZOS)
buffered = io.BytesIO()
img.save(buffered, format="JPEG", quality=85)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
except Exception as e:
print(f"Error encoding image: {e}")
return None
def _parse_text(text):
"""Parse text for display formatting"""
if text is None:
return ""
text = str(text)
lines = text.split('\n')
lines = [line for line in lines if line != '']
count = 0
for i, line in enumerate(lines):
if "<think>" in line:
line = line.replace("<think>", "**Reasoning Process**:\n")
if "</think>" in line:
line = line.replace("</think>", "")
if "<answer>" in line:
line = line.replace("<answer>", "**Final Answer**:\n")
if "</answer>" in line:
line = line.replace("</answer>", "")
if '```' in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = '<br></code></pre>'
else:
if i > 0:
if count % 2 == 1:
line = line.replace('`', r'\`')
line = line.replace('<', '&lt;')
line = line.replace('>', '&gt;')
line = line.replace(' ', '&nbsp;')
line = line.replace('*', '&ast;')
line = line.replace('_', '&lowbar;')
line = line.replace('-', '&#45;')
line = line.replace('.', '&#46;')
line = line.replace('!', '&#33;')
line = line.replace('(', '&#40;')
line = line.replace(')', '&#41;')
line = line.replace('$', '&#36;')
lines[i] = '<br>' + line
text = ''.join(lines)
return text
def _remove_image_special(text):
"""Remove special image tags from text"""
if text is None:
return ""
text = text.replace('<ref>', '').replace('</ref>', '')
return re.sub(r'<box>.*?(</box>|$)', '', text)
def _gc():
"""Garbage collection to free memory"""
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _transform_messages(original_messages, system_prompt):
"""Transform messages with custom system prompt"""
transformed_messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_content.append({'type': 'image', 'image': item['image']})
elif 'text' in item:
new_content.append({'type': 'text', 'text': item['text']})
if new_content:
transformed_messages.append({'role': message['role'], 'content': new_content})
return transformed_messages
def normalize_task_history_item(item):
"""Normalize items in task_history to a dictionary format"""
if isinstance(item, dict):
return {'text': item.get('text', ''), 'images': item.get('images', []), 'response': item.get('response', None)}
elif isinstance(item, (list, tuple)) and len(item) >= 2:
query, response = item[0], item[1]
if isinstance(query, (list, tuple)):
return {'text': '', 'images': list(query), 'response': response}
else:
return {'text': str(query) if query else '', 'images': [], 'response': response}
else:
return {'text': str(item) if item else '', 'images': [], 'response': None}
def _launch_demo(args):
"""Launch the Gradio demo interface"""
@spaces.GPU(duration=360)
def call_local_model(messages, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
"""Call the local model with streaming response - loads model lazily"""
model, processor, device = _get_model_and_processor(args.checkpoint_path)
messages = _transform_messages(messages, system_prompt)
# Apply pan and scan settings dynamically at inference time
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
do_pan_and_scan=do_pan_and_scan,
pan_and_scan_max_num_crops=pan_scan_max_crops,
pan_and_scan_min_ratio_to_activate=pan_scan_min_ratio,
pan_and_scan_min_crop_size=DEFAULT_PAN_SCAN_MIN_CROP_SIZE,
)
inputs = inputs.to(device)
tokenizer = processor.tokenizer
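        # The streamer is fed by the background generate() thread and yields
        # decoded text chunks as soon as they become available.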
streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {
'max_new_tokens': max_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": 20,
'streamer': streamer,
**inputs
}
        # NOTE: inference_mode() is thread-local and generate() runs in a worker
        # thread; transformers already wraps generate() in no_grad, so gradients
        # are not tracked either way.
        with torch.inference_mode():
            thread = Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()
generated_text = ''
for new_text in streamer:
generated_text += new_text
display_text = generated_text
if "<think>" in display_text:
display_text = display_text.replace("<think>", "**Reasoning Process**:\n")
if "</think>" in display_text:
display_text = display_text.replace("</think>", "\n")
if "<answer>" in display_text:
display_text = display_text.replace("<answer>", "**Final Answer**:\n")
if "</answer>" in display_text:
display_text = display_text.replace("</answer>", "")
yield display_text, generated_text
@spaces.GPU(duration=360)
def predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
if not _chatbot or not task_history:
yield _chatbot
return
chat_query = _chatbot[-1][0]
last_item = normalize_task_history_item(task_history[-1])
if not chat_query and not last_item['text'] and not last_item['images']:
_chatbot.pop()
task_history.pop()
yield _chatbot
return
print(f'User query: {last_item}')
history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
full_response_raw = ''
messages = []
# Limit history to last 3 turns to save memory
history_cp = history_cp[-3:]
        for item in history_cp:
            content = []
            # Encode every image attached to this turn
            if item['images']:
                for img_path in item['images']:
                    if img_path:
                        encoded_img = encode_image_pil(img_path)
                        if encoded_img:
                            content.append({'image': encoded_img})
            if item['text']:
                content.append({'text': str(item['text'])})
            if content:
                messages.append({'role': 'user', 'content': content})
            # Completed turns also carry the assistant's stored reply
            if item['response'] is not None:
                messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})
try:
for response_display, response_raw in call_local_model(
messages, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio
):
                # chat_query was already parsed when it was appended to the chat,
                # so parse only the model response here
                _chatbot[-1] = (chat_query, _remove_image_special(_parse_text(response_display)))
yield _chatbot
full_response_raw = response_raw
task_history[-1]['response'] = full_response_raw
print(f'Assistant: {full_response_raw}')
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
error_msg = f"Error: {str(e)}"
            _chatbot[-1] = (chat_query, error_msg)
task_history[-1]['response'] = error_msg
yield _chatbot
@spaces.GPU(duration=360)
def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio):
if not task_history or not _chatbot:
yield _chatbot
return
last_item = normalize_task_history_item(task_history[-1])
if last_item['response'] is None:
yield _chatbot
return
last_item['response'] = None
task_history[-1] = last_item
_chatbot.pop(-1)
display_message_parts = []
if last_item['images']:
display_message_parts.append(f"[Uploaded {len(last_item['images'])} image(s)]")
if last_item['text']:
display_message_parts.append(last_item['text'])
display_message = " ".join(display_message_parts)
_chatbot.append([_parse_text(display_message), None])
for updated_chatbot in predict(
_chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio
):
yield updated_chatbot
def add_text_and_files(history, task_history, text, files):
history = history if history is not None else []
task_history = task_history if task_history is not None else []
has_text = text and text.strip()
has_files = files and len(files) > 0
if not has_text and not has_files:
return history, task_history, text, files
display_parts, file_paths = [], []
        if has_files:
            # gr.File may return tempfile-like objects or plain path strings
            # depending on the Gradio version; accept both.
            for file in files:
                if file is None:
                    continue
                file_paths.append(file.name if hasattr(file, 'name') else str(file))
if file_paths:
display_parts.append(f"[Uploaded {len(file_paths)} image(s)]")
if has_text:
display_parts.append(text)
display_message = " ".join(display_parts)
history.append([_parse_text(display_message), None])
task_history.append({'text': text if has_text else '', 'images': file_paths, 'response': None})
return history, task_history, '', None
def reset_state():
_gc()
return [], [], None
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
gr.HTML("""
<div class="container">
<h1 class="main-title">IntrinSight Assistant</h1>
<p class="sub-title">
Powered by IntrinSight-4B Model (ZeroGPU)
</p>
</div>
""")
task_history = gr.State([])
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(
label='IntrinSight-4B Chat Interface',
elem_classes='control-height',
height=600,
avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
)
with gr.Row():
query = gr.Textbox(
lines=3,
label='πŸ’¬ Message Input',
placeholder="Enter your question here...",
elem_classes="custom-input"
)
with gr.Row():
addfile_btn = gr.File(
label="πŸ“Έ Upload Images",
file_count="multiple",
file_types=["image"],
elem_classes="file-upload-area"
)
with gr.Row():
submit_btn = gr.Button('πŸš€ Send', variant="primary", elem_classes="custom-button")
regen_btn = gr.Button('πŸ”„ Regenerate', variant="secondary", elem_classes="custom-button")
empty_bin = gr.Button('πŸ—‘οΈ Clear History', variant="stop", elem_classes="custom-button")
with gr.Column(scale=2):
with gr.Group(elem_classes="parameter-section"):
gr.Markdown("### βš™οΈ System Configuration")
system_prompt = gr.Textbox(
label="System Prompt",
value=DEFAULT_SYSTEM_PROMPT,
lines=5,
placeholder="Enter system prompt here..."
)
with gr.Group(elem_classes="parameter-section"):
gr.Markdown("### πŸŽ›οΈ Generation Parameters")
temperature = gr.Slider(
minimum=0.1, maximum=2.0, value=0.7, step=0.1,
label="Temperature",
info="Higher values make output more random"
)
top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=1.0, step=0.05,
label="Top-p",
info="Cumulative probability for token selection"
)
max_tokens = gr.Slider(
minimum=256, maximum=32768, value=8192, step=256,
label="Max Tokens",
info="Maximum number of tokens to generate"
)
with gr.Group(elem_classes="parameter-section"):
gr.Markdown("### πŸ–ΌοΈ Image Processing (Pan & Scan)")
do_pan_and_scan = gr.Checkbox(
label="Enable Pan & Scan",
value=True,
info="Split large images into crops for better detail"
)
pan_scan_max_crops = gr.Slider(
minimum=1, maximum=6, value=DEFAULT_PAN_SCAN_MAX_CROPS, step=1,
label="Max Crops",
info="More crops = better detail but higher memory usage"
)
pan_scan_min_ratio = gr.Slider(
minimum=1.0, maximum=3.0, value=DEFAULT_PAN_SCAN_MIN_RATIO, step=0.1,
label="Min Ratio to Activate",
info="Aspect ratio threshold to trigger pan & scan"
)
gr.Markdown("""
### πŸ“‹ Instructions
**Usage:**
- Enter your question and click Send
- Upload multiple images as needed
- First request may take longer (model loading)
**Memory Tips:**
- Reduce "Max Crops" if you encounter memory errors
- Disable "Pan & Scan" for lower memory usage
- More images = higher memory consumption
### ⚠️ Disclaimer
Subject to Gemma license agreement.
""")
submit_btn.click(
add_text_and_files,
[chatbot, task_history, query, addfile_btn],
[chatbot, task_history, query, addfile_btn]
).then(
predict,
[chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
[chatbot],
show_progress="full"
)
        empty_bin.click(reset_state, outputs=[chatbot, task_history, addfile_btn], show_progress="full")
regen_btn.click(
regenerate,
[chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
[chatbot],
show_progress="full"
)
query.submit(
add_text_and_files,
[chatbot, task_history, query, addfile_btn],
[chatbot, task_history, query, addfile_btn]
).then(
predict,
[chatbot, task_history, system_prompt, temperature, top_p, max_tokens,
do_pan_and_scan, pan_scan_max_crops, pan_scan_min_ratio],
[chatbot],
show_progress="full"
)
demo.queue(max_size=5).launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
show_error=True
)
def main():
"""Main entry point"""
args = _get_args()
_launch_demo(args)
if __name__ == '__main__':
main()