Spaces: Running on Zero

陈硕 committed · commit e7b0784 · 1 parent: 538c34c

add vlm
app.py CHANGED
@@ -28,6 +28,8 @@ from diffusers import (
 )
 from diffusers.utils import load_video, load_image
 from datetime import datetime, timedelta
+from PIL import Image
+from transformers import AutoModelForCausalLM, LlamaTokenizer
 
 from diffusers.image_processor import VaeImageProcessor
 from openai import OpenAI
@@ -171,55 +173,73 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
     return temp_video_path
 
 
-def convert_prompt(prompt: str, retry_times: int = 3) -> str:
-    [... 46 removed lines not shown ...]
+def convert_prompt(prompt: str, image_path: str = None, retry_times: int = 3) -> str:
+    # Define model and tokenizer paths
+    MODEL_PATH = "THUDM/cogagent-chat-hf"
+    TOKENIZER_PATH = "lmsys/vicuna-7b-v1.5"
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    torch_type = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+    # Initialize model and tokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_PATH,
+        torch_dtype=torch_type,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True
+    ).to(DEVICE).eval()
+
+    # Conversation template for text-only queries
+    text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
+
+    # Check if image is available
+    if image_path and os.path.isfile(image_path):
+        image = Image.open(image_path).convert('RGB')
+    else:
+        image = None
+
+    # Initialize history for conversation context
+    history = []
+    query = prompt.strip()
+
+    for _ in range(retry_times):
+        if image is None:
+            # Text-only query, format as required by CogAgent
+            query = text_only_template.format(query)
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE)
+            }
+        else:
+            # Image-based input with initial query
+            input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
+            inputs = {
+                'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
+                'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
+                'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
+                'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]]
+            }
+        if 'cross_images' in input_by_model and input_by_model['cross_images']:
+            inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
+
+        # Generation settings
+        gen_kwargs = {"max_length": 2048, "do_sample": False}
+
+        with torch.no_grad():
+            outputs = model.generate(**inputs, **gen_kwargs)
+            outputs = outputs[:, inputs['input_ids'].shape[1]:]
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = response.split("</s>")[0].strip()  # Clean up response
+
+        if response:
+            return response  # Return the response if generated successfully
+
+    # Return original prompt if all retries fail
     return prompt
 
+
 @spaces.GPU
 def infer(
     prompt: str,
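For reference, the new convert_prompt is fully self-contained: it loads CogAgent and its Vicuna tokenizer on every call, builds either a text-only or an image-conditioned conversation via build_conversation_input_ids, and falls back to the original prompt if no response is produced within retry_times attempts. A minimal usage sketch under those assumptions (the prompt text and image path below are hypothetical):

    # Text-only call: no image, so the Vicuna-style chat template is applied.
    enhanced = convert_prompt("a cat, camera orbiting left")

    # Image-conditioned call: the picture is passed to CogAgent as context.
    enhanced = convert_prompt(
        "a cat, camera orbiting left",
        image_path="input.png",  # hypothetical local file
        retry_times=3,
    )

Because the checkpoint is re-initialized inside the function, every call pays the full model-loading cost; caching the model and tokenizer at module level would trade memory residency for much faster repeat calls.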
@@ -323,11 +343,11 @@ with gr.Blocks() as demo:
     """)
     with gr.Row():
         with gr.Column():
-            image_in = gr.Image(label="Image
+            image_in = gr.Image(label="Input Image (will be cropped to 720 * 480)")
             examples_component_images = gr.Examples(examples_images, inputs=[image_in], cache_examples=False)
             # prompt = gr.Textbox(label="Prompt")
             orbit_type = gr.Radio(label="Orbit type", choices=["Left", "Up"], value="Left", interactive=True)
-            submit_btn = gr.Button("Submit")
+            # submit_btn = gr.Button("Submit")
 
     # with gr.Column():
     #     with gr.Accordion("I2V: Image Input (cannot be used simultaneously with video input)", open=False):
@@ -341,9 +361,9 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         gr.Markdown(
-            "✨Upon pressing the enhanced prompt button, we will use [
+            "✨Upon pressing the enhanced prompt button, we will use [CogVLM](https://github.com/THUDM/CogVLM) to polish the prompt and overwrite the original one."
        )
-        enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
+        enhance_button = gr.Button("✨ Enhance Prompt(Optional but highly recommend)")
     with gr.Group():
         with gr.Column():
             with gr.Row():
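This hunk adds the enhance button itself, but its event wiring falls outside the diff context. A plausible hookup, sketched on the assumptions that a prompt textbox exists (called prompt_box here, hypothetically; the actual prompt component is commented out in the hunk above) and that image_in is created with type="filepath" so its value matches convert_prompt's os.path.isfile check:

    # Hypothetical wiring; component names and the filepath assumption are
    # not confirmed by this diff.
    enhance_button.click(
        fn=lambda text, img: convert_prompt(text, image_path=img),
        inputs=[prompt_box, image_in],
        outputs=[prompt_box],
    )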