| import gradio as gr |
| import spaces |
| from threading import Thread |
|
|
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration |
| from transformers import TextIteratorStreamer |
| from PIL import Image |
| from peft import PeftModel |
| import requests |
| import torch, os, re, json |
| import time |
|
|
|
|
# Hugging Face Hub IDs: the base LLaVA-Next (Mistral-7B) checkpoint and the
# LoRA/PEFT adapter fine-tuned on the combined 4k chart-misinformation dataset.
base_model = "llava-hf/llava-v1.6-mistral-7b-hf"
finetune_repo = "erwannd/llava-v1.6-mistral-7b-finetune-combined4k"


# Processor handles both image preprocessing and text tokenization.
processor = LlavaNextProcessor.from_pretrained(base_model)


# Load the base model in fp16 to halve memory use; low_cpu_mem_usage avoids
# materializing a full fp32 copy in RAM during loading.
model = LlavaNextForConditionalGeneration.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
# Wrap the base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(model, finetune_repo)
# Move to the first GPU; inputs in predict() are sent to the same device (0).
model.to("cuda:0")
|
|
|
|
@spaces.GPU
def predict(image, input_text):
    """Stream a model answer about *image* for the question *input_text*.

    Args:
        image: PIL image supplied by the Gradio Image component.
        input_text: user prompt string.

    Yields:
        The cumulative generated text so far, one chunk per new token batch,
        which Gradio renders as a live-updating textbox.
    """
    image = image.convert("RGB")
    # Mistral-style instruction template with the <image> placeholder.
    prompt = f"[INST] <image>\n{input_text} [/INST]"

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)

    # skip_prompt=True makes the streamer emit only newly generated tokens,
    # replacing the previous fragile approach of slicing a hand-reconstructed
    # prompt string off the decoded buffer.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread so we can consume the streamer
    # while tokens are being produced. The original called generate()
    # synchronously, so the loop below only started after generation had
    # already finished (the sleeps merely simulated streaming).
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=200, do_sample=False)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
|
|
| |
# Gradio I/O components wired into gr.Interface below.
image = gr.components.Image(type="pil")  # type="pil" so predict() receives a PIL.Image
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")
# Clickable example rows: (image path, prompt) pairs of misleading/normal charts.
examples = [["./examples/bar_m01.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/bar_n01.png", "Is this chart misleading? Explain"],
            ["./examples/fox_news_cropped.png", "Tell me if this chart is misleading"],
            ["./examples/line_m01.png", "Explain if this chart is misleading"],
            ["./examples/line_m04.png", "Evaluate and explain if this chart is misleading"],
            ["./examples/pie_m01.png", "Evaluate if this chart is misleading, if so explain"],
            ["./examples/pie_m02.png", "Is this chart misleading? Explain"]]
|
|
|
|
# Markdown blurb rendered under the demo title, linking the base model and dataset.
description_markdown = """Demo for [LlavaNext](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) finetuned on [charts dataset](https://huggingface.co/datasets/chart-misinformation-detection/bar_line_pie_4k)"""


title = "LlavaNext finetuned on Misleading Chart Dataset"
# Assemble the demo. predict() is a generator, so the output textbox streams.
# cache_examples=False: example outputs are generated on click, not precomputed
# at startup (avoids running the GPU model while the Space is building).
interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    theme='gradio/soft',
    cache_examples=False,
    description=description_markdown
)


interface.launch()