File size: 4,887 Bytes
111a99e
 
 
2fcfad9
111a99e
9f9c33b
2fcfad9
 
111a99e
2fcfad9
b38e046
 
111a99e
 
b38e046
9f9c33b
 
2fcfad9
 
 
 
 
 
 
9f9c33b
 
 
 
 
 
 
 
 
2fcfad9
9f9c33b
 
2fcfad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f9c33b
111a99e
 
 
 
 
 
 
 
9f9c33b
111a99e
 
 
 
 
 
 
 
9f9c33b
111a99e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f9c33b
a608f20
2fcfad9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111a99e
2fcfad9
 
111a99e
9f9c33b
 
 
111a99e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python3
import os
import json
import base64
import requests
import gradio as gr
from PIL import Image
from io import BytesIO

# Required endpoint configuration, injected via HF Spaces secrets.
ENDPOINT = os.environ.get("VLLM_ENDPOINT")
MODEL = os.environ.get("VLLM_MODEL")

# Fail fast at startup when the Space is missing either secret.
if not (ENDPOINT and MODEL):
    raise ValueError("VLLM_ENDPOINT and VLLM_MODEL environment variables must be set. Please add them as secrets in your Space settings.")


def image_to_base64(image):
    """Encode a PIL Image as a base64 string of its PNG serialization.

    The image is written to an in-memory PNG buffer; the raw bytes are
    then base64-encoded and returned as UTF-8 text.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")


def _build_user_message(message):
    """Convert a Gradio message (plain str or multimodal dict) into an
    OpenAI-style ``{"role": "user", "content": ...}`` message.

    Multimodal dicts carry optional "text" and "files" keys; each file is
    opened as an image and embedded as a base64 data URL. Unreadable files
    are skipped (best-effort) rather than failing the whole request.
    """
    # NOTE: the original code did `"files" in message`, which is a substring
    # test when `message` is a plain string and then crashed on
    # `message["files"]`. Check the type explicitly instead.
    if not isinstance(message, dict):
        return {"role": "user", "content": message if isinstance(message, str) else ""}

    files = message.get("files") or []
    text = message.get("text", "")
    if not files:
        # Text-only multimodal message.
        return {"role": "user", "content": text}

    content = []
    if text.strip():
        content.append({"type": "text", "text": text})
    for file_info in files:
        try:
            image = Image.open(file_info)
            b64_image = image_to_base64(image)
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{b64_image}"}
            })
        except Exception as e:
            # Best-effort: skip files that cannot be decoded as images.
            print(f"Error processing image: {e}")
    return {"role": "user", "content": content}


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    Send messages (with optional images) to vLLM endpoint and stream the response.

    Args:
        message: Plain string, or a Gradio multimodal dict with "text"/"files".
        history: Previous turns, already in {"role", "content"} format.
        system_message: System prompt prepended to the conversation.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature forwarded to the endpoint.
        top_p: Nucleus-sampling parameter forwarded to the endpoint.

    Yields:
        The accumulated assistant reply after each streamed SSE chunk, or a
        single "Error: ..." string if the request fails.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append(_build_user_message(message))

    payload = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True
    }

    try:
        response = requests.post(
            ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=payload,  # requests serializes and sets Content-Type itself
            stream=True,
            # (connect, read) timeouts so a dead endpoint cannot hang the UI;
            # the read timeout applies per chunk, not to the whole stream.
            timeout=(10, 120),
        )
        response.raise_for_status()

        accumulated_response = ""

        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            line = raw_line.decode('utf-8')
            if line.startswith('data: '):
                line = line[6:]  # Remove SSE 'data: ' prefix

            if line.strip() == '[DONE]':
                break

            try:
                chunk = json.loads(line)
            except json.JSONDecodeError:
                continue  # keep-alive / partial line; skip it
            choices = chunk.get('choices') or []
            if choices:
                content = choices[0].get('delta', {}).get('content', '')
                if content:
                    accumulated_response += content
                    yield accumulated_response

    except Exception as e:
        yield f"Error: {str(e)}"


# Build the Gradio interface. A ChatInterface created inside a Blocks
# context is rendered in place, so the explicit `chatbot.render()` call
# that used to follow it was a duplicate render and has been removed.
with gr.Blocks(title="πŸ’¬ Vision Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # πŸ’¬ Vision-Enabled Chat Interface
        **πŸ’‘ How to use:**
        1. Type your message in the chat box
        2. Optionally upload images by clicking the πŸ“Ž icon
        3. Adjust parameters in the accordion below if needed
        4. Press Enter or click Send
        
        The model can understand both text and images!
        """
    )

    # Streaming multimodal chat backed by `respond`; the extra inputs are
    # exposed in an accordion below the chat box.
    chatbot = gr.ChatInterface(
        respond,
        type="messages",
        multimodal=True,
        additional_inputs=[
            gr.Textbox(
                value="You are a helpful AI assistant with vision capabilities. You can understand and analyze images.",
                label="System message"
            ),
            gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )

    gr.Markdown("""
    ---
    **Note:** Configure endpoint via `VLLM_ENDPOINT` and `VLLM_MODEL` environment variables.
    """)


if __name__ == "__main__":
    demo.launch()