File size: 5,570 Bytes
547dee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
069261e
71021b9
547dee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import torch
import gradio as gr
from PIL import Image
from typing import List, Dict, Any
from transformers import AutoModel, AutoTokenizer

"""
Gradio app to run MiniCPM-V-4_5 int4 on CPU for image+text chat.
- Requires: pip install transformers accelerate gradio pillow
- Model: openbmb/MiniCPM-V-4_5-int4 (quantized, CPU-friendly)
- This script is self-contained and uses a simple multi-turn chat interface.
"""

# Hugging Face repo id of the checkpoint to serve; override via MINICPM_MODEL_ID.
MODEL_ID = os.getenv("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5-int4")

# Module-level singletons, populated on first use by load_model().
model = tokenizer = None

def load_model():
    """Load the tokenizer and model for ``MODEL_ID`` into the module globals.

    Does nothing if both globals are already populated, so it is cheap to call
    at the start of every request.
    """
    global model, tokenizer
    already_loaded = model is not None and tokenizer is not None
    if already_loaded:
        return

    # trust_remote_code: MiniCPM ships its own modeling code (including .chat()).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # CPU-only settings: no .cuda(), no bfloat16, and SDPA attention rather than
    # flash-attn (which targets GPUs).
    load_options = dict(
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.float32,
        device_map="cpu",
        quantization_config=None,
    )
    model = AutoModel.from_pretrained(MODEL_ID, **load_options)
    model.eval()


def _history_to_messages(history: List[Any]) -> List[Dict[str, Any]]:
    """Convert a Gradio chat transcript into text-only MiniCPM messages.

    Both Gradio Chatbot formats are accepted:
    - ``type="messages"`` (used by this app): dicts such as
      ``{"role": "user", "content": "hi"}``;
    - legacy ``type="tuples"``: ``(user_text, assistant_text)`` pairs.

    Entries whose content is not a plain string (e.g. ``None`` or a file
    attachment tuple) and roles other than user/assistant are skipped, because
    only text can be replayed from the transcript.
    """
    msgs: List[Dict[str, Any]] = []
    for turn in history:
        if isinstance(turn, dict):
            # Unpacking a dict directly would yield its KEYS ("role", "content"),
            # so read the fields explicitly.
            role, text = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(text, str):
                msgs.append({"role": role, "content": [text]})
            continue

        # Legacy (user_text, assistant_text) pair.
        user_text, assistant_text = turn
        for role, text in (("user", user_text), ("assistant", assistant_text)):
            if isinstance(text, str):
                msgs.append({"role": role, "content": [text]})
    return msgs


def build_messages(history: List[Dict[str, Any]], image: "Image.Image", user_input: str) -> List[Dict[str, Any]]:
    """Build the ``msgs`` list passed to MiniCPM's ``.chat()`` for one turn.

    Args:
        history: Gradio Chatbot value, either role/content dicts
            (``type="messages"``) or legacy ``(user, assistant)`` pairs.
            ``None`` is treated as an empty history.
        image: PIL image for the current turn, or ``None``. Converted to RGB if
            needed.
        user_input: Current user text; surrounding whitespace is stripped.

    Returns:
        The replayed history messages, followed by one user message whose content
        is ``[image?, text]``. ``text`` is ``""`` when no text was given, so the
        content list is never empty.

    NOTE(review): images from earlier turns are not kept in the transcript, so
    only the current turn's image reaches the model.
    """
    msgs = _history_to_messages(history or [])

    content: List[Any] = []
    if image is not None:
        if image.mode != "RGB":
            image = image.convert("RGB")
        content.append(image)
    content.append(user_input.strip() if user_input else "")

    msgs.append({"role": "user", "content": content})
    return msgs


def respond(user_text: str, image: Image.Image, chat_history: List[Dict[str, Any]], enable_thinking: bool):
    """Gradio handler: run one chat turn and return ``(new_history, "")``.

    Args:
        user_text: Text from the message box (may be empty or ``None``).
        image: Optional PIL image attached to this turn.
        chat_history: Current Chatbot value in ``type="messages"`` format
            (a list of ``{"role", "content"}`` dicts); ``None`` counts as empty.
        enable_thinking: Forwarded to ``model.chat`` unchanged.

    Returns:
        The updated transcript, with the user and assistant messages appended,
        and ``""`` to clear the message box. If there is neither text nor an
        image, the model is not called and the transcript comes back unchanged.
    """
    # Copy so the caller's list is never mutated; tolerate a None history.
    history = list(chat_history or [])
    text = (user_text or "").strip()

    # Nothing to ask: skip a (slow, CPU-bound) model call on empty input.
    if not text and image is None:
        return history, ""

    load_model()
    msgs = build_messages(history, image, user_text)

    with torch.inference_mode():
        answer = model.chat(
            msgs=msgs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )

    # The Chatbot uses type="messages", so append role/content dicts, not pairs.
    # An image-only turn is shown with a placeholder.
    history.append({"role": "user", "content": text or "[Image]"})
    history.append({"role": "assistant", "content": answer})
    return history, ""


def clear_history():
    """Reset the chat UI state.

    Returns values in the order of the ``[chatbot, img, user_in]`` outputs:
    an empty transcript, no image, and a blank message box.
    """
    empty_transcript = []
    no_image = None
    blank_message = ""
    return empty_transcript, no_image, blank_message


def demo_app():
    """Build and return the Gradio Blocks UI; the caller is responsible for launching it.

    Layout: a wide left column with the chat transcript, an optional image
    upload, the message box and controls; a narrow right column with model
    info. Both the Send button and pressing Enter in the message box call
    ``respond``, and the Clear button calls ``clear_history``.
    """
    header_md = "## MiniCPM-V-4_5-int4 (CPU) Demo\nUpload an image (optional) and ask a question."
    model_info_md = f"- ID: `{MODEL_ID}`\n- Device: CPU\n- Quant: int4"

    with gr.Blocks(title="MiniCPM-V-4_5-int4 (CPU) - Gradio", theme="soft") as app:
        gr.Markdown(header_md)
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=420, type="messages", avatar_images=(None, None))
                with gr.Row():
                    image_box = gr.Image(type="pil", label="Image (optional)", height=240)
                message_box = gr.Textbox(
                    label="Your message",
                    placeholder="Ask something about the image or chat without an image...",
                    lines=3,
                )
                with gr.Row():
                    thinking_toggle = gr.Checkbox(value=False, label="Enable thinking mode")
                    send_button = gr.Button("Send", variant="primary")
                    clear_button = gr.Button("Clear")

            with gr.Column(scale=1):
                gr.Markdown("### Model")
                gr.Markdown(model_info_md)

        # Clicking Send and pressing Enter share identical wiring.
        chat_inputs = [message_box, image_box, chatbot, thinking_toggle]
        chat_outputs = [chatbot, message_box]
        send_button.click(fn=respond, inputs=chat_inputs, outputs=chat_outputs)
        message_box.submit(fn=respond, inputs=chat_inputs, outputs=chat_outputs)
        clear_button.click(
            fn=clear_history,
            inputs=[],
            outputs=[chatbot, image_box, message_box],
        )

    return app


if __name__ == "__main__":
    # Hide GPUs unless the caller already chose devices, keeping this process
    # CPU-only. NOTE(review): this runs after `import torch`; it presumably still
    # works because CUDA is initialized lazily — confirm.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    demo = demo_app()
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )