from __future__ import annotations

import torch

# Replace torch.jit.script with a no-op so @torch.jit.script decorators in
# dependencies do not trigger TorchScript compilation in this environment.
torch.jit.script = lambda f: f

import hashlib
import os
import shlex
import subprocess
import sys
import time
from threading import Thread

import gradio as gr
import spaces
from PIL import Image
from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
# Install the prebuilt wheels for the Mamba packages bundled with this Space.
def install():
    print("Install personal packages", flush=True)
    subprocess.run(shlex.split("pip install causal_conv1d-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))
    subprocess.run(shlex.split("pip install mamba_ssm-1.2.0.post1-cp310-cp310-linux_x86_64.whl"))


install()
# Load the Cobra VLM (after the Mamba kernels are installed) and move it to
# the GPU in bfloat16 when available, otherwise run on CPU in float32.
from cobra import load

vlm = load("cobra+3b")

if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

vlm.to(DEVICE, dtype=DTYPE)
prompt_builder = vlm.get_prompt_builder()
def bot_streaming(message, history, temperature, top_k, max_new_tokens):
    streamer = TextIteratorStreamer(vlm.llm_backbone.tokenizer, skip_special_tokens=True)

    # Reset the prompt builder when a new conversation starts.
    if len(history) == 0:
        prompt_builder.prompt, prompt_builder.turn_count = "", 0

    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        # If no image was uploaded for this turn, look for images in past turns
        # (kept inside tuples) and take the last one.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is not None:
        image = Image.open(image).convert("RGB")

    prompt_builder.add_turn(role="human", message=message["text"])
    prompt_text = prompt_builder.get_prompt()

    generation_kwargs = {
        "image": image,
        "prompt_text": prompt_text,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "temperature": temperature,
        "do_sample": True,
        "top_k": top_k,
    }

    # Generate from the VLM in a background thread and stream partial output.
    thread = Thread(target=vlm.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        # Skip the echoed prompt up to and including the assistant tag,
        # then accumulate and yield the generated text.
        if not output_started:
            if "<|assistant|>\n" in new_text:
                output_started = True
            continue
        buffer += new_text
        if len(buffer) > 1:
            yield buffer

    prompt_builder.add_turn(role="gpt", message=buffer)
    return buffer
demo = gr.ChatInterface(
    fn=bot_streaming,
    additional_inputs=[
        gr.Slider(0, 1, value=0.2, label="Temperature"),
        gr.Slider(1, 3, value=1, step=1, label="Top k"),
        gr.Slider(1, 2048, value=256, step=1, label="Max New Tokens"),
    ],
    title="Cobra",
    description="Try [Cobra](https://huggingface.co/papers/2403.14520) in this demo. Upload an image and start chatting about it. Clear the history before asking questions about a new image.",
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)