File size: 6,757 Bytes
88a5b8d 02b26ea 4aff560 02b26ea 4aff560 88a5b8d 4aff560 88a5b8d 4aff560 d0df48e 02b26ea d0df48e 02b26ea 88a5b8d 02b26ea d0df48e 4aff560 02b26ea d0df48e 4aff560 02b26ea d0df48e 02b26ea 8178d81 88a5b8d 4aff560 02b26ea 52eb57f 02b26ea d0df48e 02b26ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
import base64
import html
import mimetypes
import time

import gradio as gr
from huggingface_hub import InferenceClient
def progress_bar_html(label: str) -> str:
    """
    Build the HTML for a labelled, thin, animated progress bar.

    The snippet is a flex row holding the label text and a 110x5 px track
    whose inner bar slides sideways via the `loading` keyframes declared in
    the trailing <style> element. `label` is inserted verbatim.
    """
    markup_lines = (
        "",
        '<div style="display: flex; align-items: center;">',
        f'<span style="margin-right: 10px; font-size: 14px;">{label}</span>',
        '<div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">',
        '<div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>',
        "</div>",
        "</div>",
        "<style>",
        "@keyframes loading {",
        "0% { transform: translateX(-100%); }",
        "100% { transform: translateX(100%); }",
        "}",
        "</style>",
        "",
    )
    # The empty first and last entries reproduce the leading and trailing newline.
    return "\n".join(markup_lines)
# Model identifier passed as `model=` to the InferenceClient in model_inference.
model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
def _extract_hf_token(additional_inputs):
    """Return the token carried by the first usable extra input, or None.

    An input is usable when it has a `token` attribute, is a dict with a
    "token" key, or is a plain string (taken to be the token itself).
    `None` entries and values of any other shape are skipped.
    """
    for candidate in additional_inputs:
        if candidate is None:
            continue
        if hasattr(candidate, "token"):
            return candidate.token
        if isinstance(candidate, dict) and "token" in candidate:
            return candidate["token"]
        if isinstance(candidate, str):
            return candidate
    return None


def _image_part(file_ref):
    """Convert one attached file into an `image_url` chat content part.

    Remote http(s) URLs are forwarded unchanged; anything else is opened as
    a local file and embedded as a base64 data URL.
    """
    if isinstance(file_ref, str) and file_ref.startswith(("http://", "https://")):
        return {"type": "image_url", "image_url": {"url": file_ref}}

    with open(file_ref, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode("utf-8")

    # Label the payload with its real type (e.g. image/png); fall back to
    # JPEG only when the file name gives no usable image type.
    mime_type, _ = mimetypes.guess_type(str(file_ref))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    return {
        "type": "image_url",
        "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
    }


def model_inference(input_dict, history, *additional_inputs):
    """
    Stream a multimodal chat completion through Hugging Face InferenceClient.

    Signature matches the ChatInterface call pattern:
    (input_dict, history, *additional_inputs).

    Args:
        input_dict: Mapping with a "text" entry (the query) and a "files"
            entry (attached images as URLs or local paths). Missing or
            None values are treated as empty.
        history: Previous conversation turns. Accepted for interface
            compatibility but not used; only the current input is sent.
        *additional_inputs: Extra values from the UI. The first one that
            carries a token (object with `.token`, dict with "token", or a
            plain string) supplies the access token.

    Yields:
        The accumulated response text after every streamed chunk, or a
        single explanatory message when the query is empty, the user is
        not logged in, or the model returned no text.
    """
    text = input_dict.get("text") or ""
    files = input_dict.get("files") or []

    if not text:
        # Always yield once so the streaming consumer receives a value.
        if files:
            yield "Please input a text query along with the image(s)."
        else:
            yield "Please input a query and optionally image(s)."
        return

    # Check the login before reading and encoding files that would be discarded.
    token = _extract_hf_token(additional_inputs)
    if not token:
        yield "Please login with a Hugging Face account (use the Login button in the sidebar)."
        return

    # Content order: images first, then the text query.
    content_list = []
    for file_ref in files:
        try:
            content_list.append(_image_part(file_ref))
        except (OSError, TypeError, ValueError):
            # Best effort: an unreadable attachment is left out of the request.
            continue
    content_list.append({"type": "text", "text": text})
    messages = [{"role": "user", "content": content_list}]

    client = InferenceClient(token=token, model=model_name, provider="hf-inference")

    response = ""
    for chunk in client.chat_completion(messages, max_tokens=1024, stream=True):
        choices = chunk.choices
        if choices and choices[0].delta.content:
            response += choices[0].delta.content
        yield response

    # An empty stream (or one with no text) would otherwise leave the chat blank.
    if not response:
        yield "(no text was returned by the model)"
# Sample queries offered in the chat UI; each row is a single multimodal
# message made of a prompt plus the bundled example image.
examples = [
    [{"text": prompt, "files": ["example_images/example.png"]}]
    for prompt in (
        "Write a descriptive caption for this image in a formal tone.",
        "What are the characters wearing?",
    )
]
# UI layout: a sidebar holding the login button, and a multimodal chat
# interface that receives that button as an additional input.
with gr.Blocks() as demo:
    with gr.Sidebar():
        # Created without arguments: some Gradio versions may reject a
        # `label` kwarg on LoginButton.
        login_btn = gr.LoginButton()

    # NOTE(review): the heading names a 500M model and CPU execution, while
    # model_name points at "HuggingFaceTB/SmolVLM2-256M-Video-Instruct" --
    # confirm this text is still accurate.
    chat_description = "# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history."
    query_box = gr.MultimodalTextbox(label="Query Input", file_types=["image"])

    # Building the ChatInterface inside the Blocks context is enough to place
    # it; chatbot.render() is deliberately not called, as that can duplicate it.
    chatbot = gr.ChatInterface(
        fn=model_inference,
        description=chat_description,
        examples=examples,
        fill_height=True,
        textbox=query_box,
        stop_btn="Stop Generation",
        multimodal=True,
        cache_examples=False,
        additional_inputs=[login_btn],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
|