# CogVLM2 (llama3-chat-19B) image-captioning demo — Hugging Face Space app.
import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model card: https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# Quantized alternative: https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
# Enable accelerated downloads (hf_transfer) before fetching the snapshot.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
# Resolve the repo id to a local snapshot directory; from here on MODEL_PATH
# is a filesystem path, not a hub repo id.
MODEL_PATH = snapshot_download(MODEL_PATH)
move_cache()  # migrate any legacy transformers cache layout in place

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 requires compute capability >= 8 (Ampere or newer); otherwise
# fall back to float16.
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
## TOKENIZER ##
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True  # CogVLM2 ships custom tokenizer code in the repo
)

## MODEL ##
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,  # custom modeling code (vision tower + LM) from the repo
).to(DEVICE).eval()

# Prompt template for a text-only turn (no image attached).
# NOTE(review): currently unused — generate_caption builds its own query string.
text_only_template = """USER: {} ASSISTANT:"""
def generate_caption(image, prompt):
    """Answer *prompt* about *image* with CogVLM2 and return the decoded text.

    Args:
        image: PIL image from the Gradio upload widget, or None if the user
            submitted without uploading anything.
        prompt: Instruction text, e.g. "Describe the image in great detail".

    Returns:
        The model's response string, truncated at the end-of-text marker.

    Raises:
        gr.Error: if no image was provided.
    """
    # NOTE(review): on HF ZeroGPU Spaces this function would normally be
    # decorated with @spaces.GPU — `spaces` is imported above but never used;
    # confirm whether the decorator was dropped by mistake.
    if image is None:
        # Guard up front: previously a missing upload crashed below with
        # AttributeError on image.convert(), and the `is not None` check in
        # the inputs dict was dead code (it ran after the crash point).
        raise gr.Error("Please upload an image first.")
    image = image.convert('RGB')
    query = f"USER: {prompt} ASSISTANT:"
    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        query=query,
        history=[],
        images=[image],
        template_version='chat'
    )
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        # image is guaranteed non-None here; cast vision features to the model dtype
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens; keep only the newly generated ones.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response
## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
image_input = gr.Image(type="pil", label="Upload Image")
prompt_input = gr.Textbox(label="Prompt", value="Describe the image in great detail")
caption_output = gr.Textbox(label="Generated Caption")

# Wire the captioning function into a simple two-input, one-output UI.
demo = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, prompt_input],
    outputs=caption_output,
)

# Launch the interface
demo.launch(share=True)