File size: 4,176 Bytes
c4f1d5b
e7b948c
c4f1d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32feaab
c4f1d5b
32feaab
c4f1d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32feaab
c4f1d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32feaab
c4f1d5b
32feaab
c4f1d5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7b948c
32feaab
c4f1d5b
32feaab
e7b948c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171

import gradio as gr
import requests
import base64
import tempfile
import os
from PIL import Image
import numpy as np

# ==============================
# Configuration
# ==============================

# Access token read from the environment; optional but recommended.
HF_TOKEN = os.getenv("HF_TOKEN")

# Example models (can be extended). Each entry maps the display name shown in
# the UI to the model's modality ("type") and its inference endpoint URL.
MODEL_REGISTRY = {
    "Text - Mistral 7B Instruct": {
        "type": "text",
        "endpoint": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    },
    "Text - Llama 3 8B Instruct": {
        "type": "text",
        "endpoint": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
    },
    "Vision - LLaVA": {
        "type": "vision",
        "endpoint": "https://api-inference.huggingface.co/models/llava-hf/llava-1.5-7b-hf",
    },
    "Audio - Whisper": {
        "type": "audio",
        "endpoint": "https://api-inference.huggingface.co/models/openai/whisper-large-v3",
    },
}

# Only include the Authorization header when a token is configured, so the
# mapping never carries a None value that the HTTP layer has to filter out.
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}



# Helper Functions


def query_text_model(endpoint, prompt, timeout=60):
    """Send a text prompt to a text-generation endpoint and return its reply.

    Args:
        endpoint: URL of the inference endpoint to POST to.
        prompt: Text sent as the ``inputs`` field of the JSON payload.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``generated_text`` of the first result when the response has that
        shape; otherwise the decoded JSON body as a string, or the raw
        response text when the body is not valid JSON.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    response = requests.post(
        endpoint,
        headers=headers,
        json={"inputs": prompt},
        timeout=timeout,
    )

    try:
        body = response.json()
    except ValueError:
        # Not JSON (e.g. an HTML error page): show the raw body instead of
        # failing a second time while building the fallback message.
        return response.text

    try:
        return body[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        # Unexpected shape, typically an {"error": ...} object from the API.
        return str(body)


def query_vision_model(endpoint, prompt, image, timeout=60):
    """Send an image plus a text prompt to a vision endpoint.

    The image is PNG-encoded in memory and sent base64-encoded, so no
    temporary file is left behind on disk.

    Args:
        endpoint: URL of the inference endpoint to POST to.
        prompt: Text sent alongside the image.
        image: Image object exposing ``save(fp, format=...)`` (the UI passes
            a PIL image).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.RequestException: On connection failures or timeouts.
        ValueError: If the response body is not valid JSON.
    """
    # Local import: the module's top-level imports do not include io.
    from io import BytesIO

    buffer = BytesIO()
    image.save(buffer, format="PNG")

    payload = {
        "inputs": {
            "image": base64.b64encode(buffer.getvalue()).decode("utf-8"),
            "text": prompt,
        }
    }

    response = requests.post(
        endpoint, headers=headers, json=payload, timeout=timeout
    )
    return response.json()


def query_audio_model(endpoint, audio_path, timeout=60):
    """Upload an audio file's raw bytes to an audio endpoint.

    Args:
        endpoint: URL of the inference endpoint to POST to.
        audio_path: Path of the audio file to read and send as the body.
        timeout: Seconds to wait for the server before giving up, so a
            stalled endpoint cannot block the handler indefinitely.

    Returns:
        The decoded JSON response body.

    Raises:
        OSError: If the audio file cannot be read.
        requests.RequestException: On connection failures or timeouts.
        ValueError: If the response body is not valid JSON.
    """
    with open(audio_path, "rb") as audio_file:
        audio_bytes = audio_file.read()

    response = requests.post(
        endpoint, headers=headers, data=audio_bytes, timeout=timeout
    )
    return response.json()


# Main Chat Function


def _format_outputs(outputs):
    """Render per-model results as a single Markdown string for the chat."""
    if not outputs:
        return "No models selected."
    return "\n\n".join(
        f"**{name}**\n\n{result}" for name, result in outputs.items()
    )


def multimodal_chat(prompt, image, audio, selected_models, history):
    """Run the prompt/image/audio against each selected model and log the turn.

    Args:
        prompt: Text typed by the user.
        image: Image input, or None when no image was provided.
        audio: Path of the audio input, or None when none was provided.
        selected_models: Names of the models to query (keys of
            MODEL_REGISTRY); None or empty means no model is queried.
        history: List of (user message, bot message) pairs. It is extended
            in place; None is treated as an empty history.

    Returns:
        A 4-tuple ``(history, "", None, None)``: the updated history followed
        by the values used to reset the text, image and audio inputs.
    """
    if history is None:
        history = []

    outputs = {}

    for model_name in selected_models or []:
        model = MODEL_REGISTRY.get(model_name)
        if model is None:
            # A stale or mistyped name must not abort the remaining models.
            outputs[model_name] = "Unknown model"
            continue

        try:
            if model["type"] == "text":
                result = query_text_model(model["endpoint"], prompt)

            elif model["type"] == "vision" and image is not None:
                result = query_vision_model(model["endpoint"], prompt, image)

            elif model["type"] == "audio" and audio is not None:
                result = query_audio_model(model["endpoint"], audio)

            else:
                result = "Unsupported input for this model"

        except Exception as e:
            # Boundary handler: one failing model is reported in the chat
            # instead of breaking the comparison for the others.
            result = f"Error: {str(e)}"

        outputs[model_name] = result

    # The bot side of a chat pair must be text, so the per-model results are
    # flattened into one Markdown message rather than stored as a dict.
    history.append((prompt, _format_outputs(outputs)))

    return history, "", None, None



# UI


# Layout: chat window and inputs on the left, model selection on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Multimodal Model Comparison") as demo:

    gr.Markdown("""
    # Multimodal Chat + Model Comparison
    Compare HuggingFace models across modalities
    """)

    with gr.Row():

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)

            prompt = gr.Textbox(
                placeholder="Enter your prompt...",
                label="Text Input"
            )

            with gr.Row():
                # The image arrives as a PIL object and the audio as a file
                # path, matching what the query helpers are called with.
                image_input = gr.Image(type="pil", label="Image Input")
                audio_input = gr.Audio(type="filepath", label="Audio Input")

            submit = gr.Button("Send")

        with gr.Column(scale=1):
            gr.Markdown("### Model Selection")

            model_selector = gr.CheckboxGroup(
                choices=list(MODEL_REGISTRY.keys()),
                value=["Text - Mistral 7B Instruct"],
                label="Select Models"
            )

            clear = gr.Button("Clear")

    # Per-session conversation history handed to multimodal_chat.
    state = gr.State([])

    # The handler returns the history plus reset values for the three inputs.
    submit.click(
        multimodal_chat,
        inputs=[prompt, image_input, audio_input, model_selector, state],
        outputs=[chatbot, prompt, image_input, audio_input]
    )

    # Reset both the visible chat and the stored history; clearing only the
    # Chatbot would let the old turns reappear on the next submit.
    clear.click(
        lambda: ([], []),
        None,
        [chatbot, state]
    )



# Launch


if __name__ == "__main__":
    # Start the Gradio app only when this file is executed as a script,
    # not when it is imported.
    demo.launch()