File size: 4,992 Bytes
6bced81 6b01bf3 6bced81 6b01bf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import cv2
import gradio as gr
from llm import LLM
import llm as llm
import argparse
import socket
parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--hf_model", type=str, default="./InternVL3-2B",
help="Path to HuggingFace model")
parser.add_argument("--axmodel_path", type=str, default="./InternVL3-2B_axmodel",
help="Path to save compiled axmodel of llama model")
parser.add_argument("--vit_model", type=str, default=None,
help="Path to save compiled axmodel of llama model")
args = parser.parse_args()
hf_model_path = args.hf_model
axmodel_path = args.axmodel_path
vit_axmodel_path = args.vit_model
gllm = LLM(hf_model_path, axmodel_path, vit_axmodel_path)
# 获取本地 IP 地址
def get_local_ip():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
ip = s.getsockname()[0]
s.close()
return ip
except Exception:
return "127.0.0.1"
def stop_generation():
gllm.stop_generate()
def respond(prompt, video, image, is_image, video_segments, image_segments_cols, image_segments_rows, history=None):
if history is None:
history = []
if not prompt.strip():
return history
# append empty response to history
gllm.tag = "video" if not is_image else "image"
history.append((prompt, ""))
yield history
print(video)
print(image)
if is_image:
img = cv2.imread(image)
images_list = []
if image_segments_cols == 1 and image_segments_rows == 1:
images_list.append(img)
elif image_segments_cols * image_segments_rows > 8:
# gr.Error("image segments cols * image segments rows > 8")
history[-1] = (prompt, history[-1][1] + "image segments cols * image segments rows > 8")
yield history
return
else:
height, width, _ = img.shape
segment_width = width // image_segments_cols
segment_height = height // image_segments_rows
for i in range(image_segments_rows):
for j in range(image_segments_cols):
x1 = j * segment_width
y1 = i * segment_height
x2 = (j + 1) * segment_width
y2 = (i + 1) * segment_height
segment = img[y1:y2, x1:x2]
images_list.append(segment)
else:
images_list = llm.load_video_opencv(video, num_segments = video_segments)
for msg in gllm.generate(images_list, prompt):
print(msg, end="", flush=True)
history[-1] = (prompt, history[-1][1] + msg)
yield history
print("\n\n\n")
def chat_interface():
with gr.Blocks() as demo:
gr.Markdown("## InternVL3 Chat DEMO \nUpload an image or video and chat with the InternVL3.")
with gr.Row():
with gr.Column(scale=1):
video = gr.Video(label="Upload Video", format="mp4")
video_segments = gr.Slider(minimum=2, maximum=8, step=1, value=4, label="video segments")
image = gr.Image(label="Upload Image", type="filepath")
image_segments_cols = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image cols segments")
image_segments_rows = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="image rows segments")
checkbox = gr.Checkbox(label="Use Image")
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=650)
prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这组图片")
with gr.Row():
btn_chat = gr.Button("Chat", variant="primary")
btn_stop = gr.Button("Stop", variant="stop")
btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
btn_chat.click(
fn=respond,
inputs=[prompt, video, image, checkbox, video_segments, image_segments_cols, image_segments_rows, chatbot],
outputs=chatbot
)
def on_video_uploaded(video):
if video is not None:
return gr.update(value=False)
return gr.update()
def on_image_uploaded(image):
if image is not None:
return gr.update(value=True)
return gr.update()
video.change(fn=on_video_uploaded, inputs=video, outputs=checkbox)
image.change(fn=on_image_uploaded, inputs=image, outputs=checkbox)
local_ip = get_local_ip()
server_port = 7860
print(f"HTTP 服务地址: http://{local_ip}:{server_port}")
demo.launch(server_name=local_ip, server_port=server_port)
if __name__ == "__main__":
chat_interface()
|