Spaces:
Runtime error
Runtime error
import base64
import io
import json
import os
import tempfile

import gradio as gr
import numpy as np
import requests
from PIL import Image
# ==============================
# Configuration
# ==============================
# Hugging Face API token read from the environment. Optional but
# recommended; when unset, requests are sent without an Authorization header.
HF_TOKEN = os.getenv("HF_TOKEN")

# Display name -> model metadata (can be extended).
#   "type":     which input modality the model handles; multimodal_chat uses
#               it to pick the matching query helper ("text"/"vision"/"audio").
#   "endpoint": Inference API URL the request is POSTed to.
MODEL_REGISTRY = {
    "Text - Mistral 7B Instruct": {
        "type": "text",
        "endpoint": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2",
    },
    "Text - Llama 3 8B Instruct": {
        "type": "text",
        "endpoint": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct",
    },
    "Vision - LLaVA": {
        "type": "vision",
        "endpoint": "https://api-inference.huggingface.co/models/llava-hf/llava-1.5-7b-hf",
    },
    "Audio - Whisper": {
        "type": "audio",
        "endpoint": "https://api-inference.huggingface.co/models/openai/whisper-large-v3",
    },
}

# Shared HTTP headers for every Inference API call. Only include the
# Authorization header when a token is configured, rather than sending a
# None-valued header and relying on requests to drop it.
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Helper Functions
def query_text_model(endpoint, prompt, timeout=60):
    """Send *prompt* to a text-generation endpoint and return its reply.

    Args:
        endpoint: Inference API URL of the model.
        prompt: Text sent as the request's ``inputs`` field.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The first result's ``generated_text`` when the response has that
        shape. Otherwise a string rendering of the response body (e.g. an
        API error payload), so the caller can show it to the user.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    payload = {"inputs": prompt}
    response = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
    try:
        data = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page): report status and raw text
        # instead of raising a second decode error.
        return f"HTTP {response.status_code}: {response.text}"
    try:
        return data[0]["generated_text"]
    except (KeyError, IndexError, TypeError):
        # Payloads such as {"error": ...} lack the list-of-dicts shape.
        return str(data)
def query_vision_model(endpoint, prompt, image, timeout=60):
    """Send *image* together with a text *prompt* to a vision endpoint.

    The image is PNG-encoded in memory and embedded as base64 in the JSON
    payload under ``inputs.image``; the prompt goes under ``inputs.text``.

    Args:
        endpoint: Inference API URL of the model.
        prompt: Text question/instruction about the image.
        image: PIL image (the UI's ``gr.Image`` is configured with type="pil").
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.RequestException: on connection failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    # Encode in memory rather than through a temporary file, so no file
    # handle or on-disk file is left behind per request.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    payload = {
        "inputs": {
            "image": base64.b64encode(buffer.getvalue()).decode("utf-8"),
            "text": prompt,
        }
    }
    response = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
    return response.json()
def query_audio_model(endpoint, audio_path, timeout=60):
    """Upload the audio file at *audio_path* to a speech endpoint.

    The raw file bytes are sent as the request body.

    Args:
        endpoint: Inference API URL of the model.
        audio_path: Path to the audio file (the UI's ``gr.Audio`` is
            configured with type="filepath").
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON response body.

    Raises:
        OSError: if the file cannot be read.
        requests.RequestException: on connection failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    with open(audio_path, "rb") as audio_file:
        audio_bytes = audio_file.read()
    response = requests.post(endpoint, headers=headers, data=audio_bytes, timeout=timeout)
    return response.json()
# Main Chat Function
def multimodal_chat(prompt, image, audio, selected_models, history):
    """Query every selected model and append one combined turn to *history*.

    Args:
        prompt: User text; sent to text and vision models.
        image: Optional PIL image; only used by vision models.
        audio: Optional audio file path; only used by audio models.
        selected_models: MODEL_REGISTRY keys chosen in the UI (None or empty
            means nothing was selected).
        history: List of (user_message, bot_message) pairs; mutated in place.

    Returns:
        ``(history, "", None, None)``: the updated history for the Chatbot,
        followed by values that clear the prompt, image and audio inputs.
    """
    sections = []
    for model_name in selected_models or []:
        model = MODEL_REGISTRY[model_name]
        try:
            if model["type"] == "text":
                result = query_text_model(model["endpoint"], prompt)
            elif model["type"] == "vision" and image is not None:
                result = query_vision_model(model["endpoint"], prompt, image)
            elif model["type"] == "audio" and audio is not None:
                result = query_audio_model(model["endpoint"], audio)
            else:
                result = "Unsupported input for this model"
        except Exception as e:  # one failing model must not hide the others' results
            result = f"Error: {e}"
        sections.append(f"**{model_name}**\n\n{_format_result(result)}")

    # The Chatbot displays each message as text, so all model outputs are merged
    # into a single markdown reply instead of appending a raw dict.
    reply = "\n\n---\n\n".join(sections) if sections else "No models selected."
    history.append((prompt, reply))
    return history, "", None, None


def _format_result(result):
    """Render one model result as markdown: strings pass through, JSON data is pretty-printed."""
    if isinstance(result, str):
        return result
    pretty = json.dumps(result, indent=2, ensure_ascii=False, default=str)
    return f"```json\n{pretty}\n```"
# UI
with gr.Blocks(theme=gr.themes.Soft(), title="Multimodal Model Comparison") as demo:
    gr.Markdown(
        """
# Multimodal Chat + Model Comparison
Compare HuggingFace models across modalities
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
            prompt = gr.Textbox(
                placeholder="Enter your prompt...",
                label="Text Input",
            )
            with gr.Row():
                image_input = gr.Image(type="pil", label="Image Input")
                audio_input = gr.Audio(type="filepath", label="Audio Input")
            submit = gr.Button("Send")
        with gr.Column(scale=1):
            gr.Markdown("### Model Selection")
            model_selector = gr.CheckboxGroup(
                choices=list(MODEL_REGISTRY.keys()),
                value=["Text - Mistral 7B Instruct"],
                label="Select Models",
            )
            clear = gr.Button("Clear")

    # Per-session conversation history passed to multimodal_chat.
    state = gr.State([])

    chat_inputs = [prompt, image_input, audio_input, model_selector, state]
    chat_outputs = [chatbot, prompt, image_input, audio_input]
    submit.click(multimodal_chat, inputs=chat_inputs, outputs=chat_outputs)
    # Pressing Enter in the textbox behaves like clicking "Send".
    prompt.submit(multimodal_chat, inputs=chat_inputs, outputs=chat_outputs)

    # Reset both the visible chat and the stored history; clearing only the
    # Chatbot let earlier turns reappear on the next message.
    clear.click(lambda: ([], []), None, [chatbot, state])

# Launch
if __name__ == "__main__":
    demo.launch()