chameleon / app.py
darknoon's picture
misc, idk why this space isn't showing up!
85c421c
raw
history blame
2.78 kB
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, ChameleonProcessor, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
from PIL import Image
import requests
# Hub repository id of the checkpoint served by this app.
model_path = "facebook/chameleon-7b"

# Earlier variant that used the model-specific class; kept for reference.
# model = ChameleonForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
# processor = ChameleonProcessor.from_pretrained(model_path)

# `token=True` replaces the deprecated `use_auth_token=True`: it tells
# `from_pretrained` to authenticate with the Hugging Face credential already
# configured in the environment.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
    token=True,
)
# Inference only: switch off training-mode layers such as dropout.
model.eval()

processor = ChameleonProcessor.from_pretrained(model_path, token=True)
# The streamer in `respond` decodes generated ids with the processor's tokenizer.
tokenizer = processor.tokenizer
def load_example_image():
    """Return the example painting, downloading it on first use.

    The opened image is cached in the module-level name ``image`` so the
    download happens at most once per process.

    Returns:
        The object produced by ``Image.open`` for the downloaded file.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    global image
    # The module never initialises `image`, so reading it directly raised
    # NameError on the very first call; look it up defensively instead.
    if globals().get("image") is None:
        response = requests.get(
            "https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg",
            stream=True,
            timeout=30,  # never hang a request forever on a stalled server
        )
        # Fail with a clear HTTP error rather than an image-decoding error
        # when the server returns an error page.
        response.raise_for_status()
        image = Image.open(response.raw)
    return image
@spaces.GPU(duration=90)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream the model's reply about the example artwork.

    Generation runs in a background thread while this generator relays the
    decoded text as it arrives.

    NOTE: the prompt and image are currently fixed; ``message``, ``history``
    and ``system_message`` are accepted for the chat-interface signature but
    are not used yet.

    Args:
        message: Latest user message (currently unused).
        history: Previous (user, assistant) turns (currently unused).
        system_message: System prompt from the UI (currently unused).
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.

    Yields:
        The reply accumulated so far, growing with each streamed chunk.
    """
    prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    image = load_example_image()

    inputs = processor(prompt, images=[image], return_tensors="pt").to(model.device, dtype=torch.bfloat16)

    # Skip the prompt and special tokens so only the model's answer reaches
    # the chat window.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Honour the UI controls instead of a hard-coded 20-token greedy decode.
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
    )

    # Launch generation in the background so tokens can be streamed while
    # the model is still producing them.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Extra controls shown with the chat box. Their values are handed to
# `respond` in this order (system message, max tokens, temperature, top-p).
_system_message_box = gr.Textbox(
    value="You are a friendly Chatbot.",
    label="System message",
)
_max_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2048,
    value=512,
    step=1,
    label="Max new tokens",
)
_temperature_slider = gr.Slider(
    minimum=0.1,
    maximum=4.0,
    value=0.7,
    step=0.1,
    label="Temperature",
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

# Chat UI whose replies are produced by `respond`.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()