# infinity / app.py — Hugging Face Space demo
# Last update by GiantPandas, commit 9ca3781 (verified)
# Standard library
import base64  # was imported twice in the original; deduplicated
import copy
import json
import os
import re
import subprocess
import sys
import tempfile
import time
from argparse import ArgumentParser
from datetime import datetime
from pathlib import Path

# Third party
import gradio as gr
import numpy as np
import tqdm
from PIL import Image
def encode_image(image_path):
    """Read the file at *image_path* and return its contents base64-encoded.

    Args:
        image_path: path to an image file on disk.

    Returns:
        The file's bytes as a base64 ASCII string (no data-URI prefix).
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
def array_to_image_path(image_array):
    """Save a numpy image array as a PNG and return its absolute path.

    Args:
        image_array: HxWxC (or HxW) numpy array; values are cast to uint8.

    Returns:
        Absolute path of the saved PNG file (written to the current
        working directory).

    Raises:
        ValueError: if *image_array* is None (no image uploaded).
    """
    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")
    # Convert the numpy array to a PIL image (cast to uint8).
    img = Image.fromarray(np.uint8(image_array))
    # BUGFIX: microsecond-resolution timestamp avoids name collisions when two
    # uploads land within the same second (the original second-resolution name
    # could silently overwrite a concurrent upload).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"image_{timestamp}.png"
    img.save(filename)
    # Return the full path so callers can reference the file unambiguously.
    return os.path.abspath(filename)
from openai import OpenAI

# The backend is an OpenAI-compatible inference server; it does not validate
# the API key, so a placeholder value is used.
openai_api_key = "EMPTY"
# Base URL of the inference endpoint, injected via environment variable
# (returns None if unset — the client would then use the library default).
openai_api_base = os.environ.get("openai_api_base")
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
# Demo revision tag; default value for the --revision CLI flag.
REVISION = 'v0.0.1'
# Matches <box>...</box> grounding spans emitted by the model.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# Halfwidth + fullwidth punctuation characters used for text post-processing.
# BUGFIX: in the pasted source the fullwidth characters were normalised to
# ASCII, which put a bare '"' inside this double-quoted literal and made the
# file unparsable; restored to the upstream Qwen demo string.
PUNCTUATION = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
def _get_args():
    """Parse and return the demo's command-line arguments.

    Returns:
        argparse.Namespace with: revision, cpu_only, share, inbrowser,
        server_port, server_name.
    """
    parser = ArgumentParser()
    parser.add_argument("--revision", type=str, default=REVISION)
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link for the interface.")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=7860,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1",
                        help="Demo server name.")
    return parser.parse_args()
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def _remove_image_special(text):
text = text.replace('<ref>', '').replace('</ref>', '')
return re.sub(r'<box>.*?(</box>|$)', '', text)
def is_video_file(filename):
    """Return True if *filename* has a recognised video extension (case-insensitive)."""
    video_extensions = ('.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg')
    # str.endswith accepts a tuple of suffixes — same result as the any() scan.
    return filename.lower().endswith(video_extensions)
def _launch_demo(args):
    """Build the Gradio chat UI and serve it.

    Conversation state lives in two parallel lists: the Chatbot component's
    value holds the HTML-escaped history shown in the UI, while
    ``task_history`` (gr.State) holds the raw history sent to the model.
    """
    # NOTE(review): computed but never used below — kept for compatibility
    # in case other code relies on the side-effect-free lookup; confirm.
    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
        Path(tempfile.gettempdir()) / "gradio"
    )

    def predict(_chatbot, task_history):
        """Stream a model reply for the latest user turn, yielding UI updates."""
        print("chatbot", _chatbot)
        print("history", task_history)
        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if len(chat_query) == 0:
            # Empty submission: drop the placeholder turn and bail out.
            _chatbot.pop()
            task_history.pop()
            return _chatbot
        print("User: " + _parse_text(query))
        history_cp = copy.deepcopy(task_history)
        full_response = ""
        messages = []
        content = []
        for q, a in history_cp:
            if isinstance(q, (tuple, list)):
                # File turn: q is a (path,) tuple from the upload button;
                # it is merged into the next text turn's content list.
                if is_video_file(q[0]):
                    content.append({'video': f'file://{q[0]}'})
                else:
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{encode_image(q[0])}"
                        }
                    })
            else:
                content.append({"type": "text", 'text': q})
                messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{"type": "text", 'text': a}]})
                content = []
        # The final assistant slot is empty (a is None) — the model fills it.
        messages.pop()
        responses = client.chat.completions.create(
            model="Qwen2_5VL",
            messages=messages,
            extra_body={},
            extra_headers={
                "apikey": "empty"
            },
            stream=True,
            temperature=0,
            top_p=1.0,
        )
        response_text = []
        for response in responses:
            delta = response.choices[0].delta.content
            # BUGFIX: streamed chunks (e.g. the role-only first chunk or the
            # final one) may carry delta.content == None; the original
            # iterated it directly and raised TypeError.
            if not delta:
                continue
            response_text.append(delta)
            _chatbot[-1] = (_parse_text(chat_query),
                            _remove_image_special(''.join(response_text)))
            yield _chatbot
        _chatbot[-1] = (_parse_text(chat_query),
                        _remove_image_special(''.join(response_text)))
        full_response = _parse_text(''.join(response_text))
        task_history[-1] = (query, full_response)
        print("Qwen2.5-VL-Chat: " + _parse_text(full_response))
        yield _chatbot

    def regenerate(_chatbot, task_history):
        """Discard the last model reply and re-run prediction on that turn."""
        if not task_history:
            return _chatbot
        item = task_history[-1]
        if item[1] is None:
            # Last turn has no answer yet — nothing to regenerate.
            return _chatbot
        task_history[-1] = (item[0], None)
        chatbot_item = _chatbot.pop(-1)
        if chatbot_item[0] is None:
            # The popped row was a bare answer row (file turn); clear the
            # answer on the preceding row instead — guard against an empty
            # list so a malformed history cannot raise IndexError.
            if _chatbot:
                _chatbot[-1] = (_chatbot[-1][0], None)
        else:
            _chatbot.append((chatbot_item[0], None))
        for _chatbot in predict(_chatbot, task_history):
            yield _chatbot

    def add_text(history, task_history, text):
        """Append a text turn to both histories and clear the textbox value.

        NOTE(review): returns three values but the click() wiring lists only
        two outputs; the textbox is cleared by a separate reset_user_input
        handler — confirm against the Gradio version in use.
        """
        task_text = text
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_file(history, task_history, file):
        """Append an uploaded image/video turn (as a (path,) tuple) to both histories."""
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        history = history + [((file.name,), None)]
        task_history = task_history + [((file.name,), None)]
        return history, task_history

    def reset_user_input():
        """Clear the input textbox."""
        return gr.update(value="")

    def reset_state(task_history):
        """Clear both the raw task history and the displayed chat."""
        task_history.clear()
        return [], []

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=3> Infinity-7B Demo </center>""")
        chatbot = gr.Chatbot(label='Infinity-7B', elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])
        with gr.Row():
            addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image", "video"])
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
        # Submitting first records the turn, then streams the prediction.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot, task_history], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
        addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

    demo.queue(default_concurrency_limit=40).launch(
        share=args.share,
        # inbrowser=args.inbrowser,
        # server_port=args.server_port,
        # server_name=args.server_name,
    )
def main():
    """CLI entry point: parse arguments and launch the Gradio demo."""
    args = _get_args()
    _launch_demo(args)


if __name__ == '__main__':
    main()