Nightwalkx committed
Commit c545fd8 · 1 parent: abd25d9

update app

Files changed (1): app.py (+75, -92)
app.py CHANGED
@@ -1,105 +1,88 @@
- import time
- from threading import Thread
-
  import gradio as gr
- import torch
  from PIL import Image
- from transformers import AutoProcessor, LlavaForConditionalGeneration
- from transformers import TextIteratorStreamer
-
  import spaces


- PLACEHOLDER = """
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
-    <img src="https://cdn-uploads.huggingface.co/production/uploads/64ccdc322e592905f922a06e/DDIW0kbWmdOQWwy4XMhwX.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55;">
-    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaVA-Llama-3-8B</h1>
-    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Llava-Llama-3-8b is a LLaVA model fine-tuned from Meta-Llama-3-8B-Instruct and CLIP-ViT-Large-patch14-336 with ShareGPT4V-PT and InternVL-SFT by XTuner</p>
- </div>
- """


- model_id = "rogerxi/llava-finetune-test"
-
- processor = AutoProcessor.from_pretrained(model_id)
-
- model = LlavaForConditionalGeneration.from_pretrained(
-     model_id,
-     torch_dtype=torch.float16,
-     low_cpu_mem_usage=True,
- )
-
- model.to("cuda:0")
- # Llama-3 ends each turn with <|eot_id|> (token id 128009)
- model.generation_config.eos_token_id = 128009


- @spaces.GPU
- def bot_streaming(message, history):
-     print(message)
-     if message["files"]:
-         # message["files"][-1] is a dict or just a string
-         if type(message["files"][-1]) == dict:
-             image = message["files"][-1]["path"]
-         else:
-             image = message["files"][-1]
-     else:
-         # no image uploaded this turn: look for images kept inside tuples
-         # in past turns and take the last one
-         for hist in history:
-             if type(hist[0]) == tuple:
-                 image = hist[0][0]
-     try:
-         if image is None:
-             # the image slot exists but is empty
-             raise gr.Error("You need to upload an image for LLaVA to work.")
-     except NameError:
-         # 'image' was never assigned at all
-         raise gr.Error("You need to upload an image for LLaVA to work.")
-
-     prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-     image = Image.open(image)
-     inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)
-
-     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
-     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
-
-     # generate() blocks, so run it on a worker thread and drain the streamer here
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     buffer = ""
-     time.sleep(0.5)
-     for new_text in streamer:
-         # strip the trailing <|eot_id|> marker from the chunk
-         if "<|eot_id|>" in new_text:
-             new_text = new_text.split("<|eot_id|>")[0]
-         buffer += new_text
-         time.sleep(0.06)
-         yield buffer
-
-
- chatbot = gr.Chatbot(placeholder=PLACEHOLDER, scale=1)
- chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)
- with gr.Blocks(fill_height=True) as demo:
-     gr.ChatInterface(
-         fn=bot_streaming,
-         title="LLaVA Llama-3-8B",
-         examples=[{"text": "What is on the flower?", "files": ["./bee.jpg"]},
-                   {"text": "How to make this pastry?", "files": ["./baklava.png"]}],
-         description="Try [LLaVA Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
-         stop_btn="Stop Generation",
-         multimodal=True,
-         textbox=chat_input,
-         chatbot=chatbot,
-     )
-
- # demo.queue(api_open=False)
  demo.launch(debug=True)
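Both the removed version above and the rewrite below stream output the same way: model.generate runs on a worker thread while the handler drains a TextIteratorStreamer and yields the growing reply to Gradio. A minimal standalone sketch of that pattern, assuming a stand-in checkpoint (the app itself pins its own model_id):

from threading import Thread

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer

model_id = "llava-hf/llava-1.5-7b-hf"  # stand-in checkpoint for illustration
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

def stream_reply(prompt, image):
    # prompt must already contain the model's <image> placeholder token
    inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda", torch.float16)
    # skip_prompt=True so the streamer emits only newly generated tokens
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until finished, so it runs on its own thread
    # while this generator consumes partial text as it arrives
    Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=256)).start()
    buffer = ""
    for chunk in streamer:
        buffer += chunk
        yield buffer  # the UI re-renders the accumulated reply each step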
 
 
 
 
  import gradio as gr
+ from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
+ from threading import Thread
+ import time
  from PIL import Image
+ import torch
  import spaces

+ # model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
+ model_id = "rogerxi/llava-finetune-test"

+ processor = LlavaProcessor.from_pretrained(model_id)

+ model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
+ model.to("cuda")


+ @spaces.GPU
+ def bot_streaming(message, history):
+     print(message)
+     txt = message['text']
+
+     # prompt text that the model echoes back; sliced off the decoded stream below
+     ext_buffer = f"user\n{txt} assistant"
+
+     if message['files']:
+         if len(message['files']) == 1:
+             image = [message['files'][0]]
+         elif len(message['files']) > 1:
+             image = [msg.path for msg in message['files']]
+     else:
+         # no image uploaded this turn: look for images kept inside tuples
+         # in past turns and take the last one
+         for hist in history:
+             if type(hist[0]) == tuple:
+                 image = hist[0][0]
+
+     if message['files'] is None:
+         raise gr.Error("You need to upload an image or video for LLaVA to work.")
+
+     if len(image) == 1:
+         image = Image.open(image[0]).convert("RGB")
+         prompt = f"<|im_start|>user <image>\n{message['text']}<|im_end|><|im_start|>assistant"
+     elif len(image) > 1:
+         image_list = []
+         user_prompt = message['text']
+         for img in image:
+             img = Image.open(img).convert("RGB")
+             image_list.append(img)
+         # one <image> placeholder per uploaded image
+         toks = "<image>" * len(image_list)
+         prompt = "<|im_start|>user" + toks + f"\n{user_prompt}<|im_end|><|im_start|>assistant"
+         image = image_list
+
+     inputs = processor(image, prompt, return_tensors="pt").to("cuda", torch.float16)
+     streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
+     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=100)
+
+     # generate() blocks, so run it on a worker thread and drain the streamer here
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         generated_text_without_prompt = buffer[len(ext_buffer):]
+         time.sleep(0.01)
+         yield generated_text_without_prompt
+
+
+ demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA",
+                         textbox=gr.MultimodalTextbox(file_count="multiple"),
+                         description="Try EgoLlava. If you don't upload an image, you will receive an error.",
+                         stop_btn="Stop Generation", multimodal=True)

  demo.launch(debug=True)
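The part of the rewritten handler most worth sanity-checking is the prompt construction: it repeats the <image> placeholder once per uploaded file inside Qwen-style <|im_start|>/<|im_end|> chat markers. A small sketch of just that step, with hypothetical file paths and question text:

from PIL import Image

def build_prompt(text, image_paths):
    # load each file and emit one <image> placeholder per image
    images = [Image.open(p).convert("RGB") for p in image_paths]
    toks = "<image>" * len(images)
    return images, f"<|im_start|>user{toks}\n{text}<|im_end|><|im_start|>assistant"

# hypothetical inputs
images, prompt = build_prompt("Compare these two photos.", ["a.jpg", "b.jpg"])
print(prompt)
# <|im_start|>user<image><image>
# Compare these two photos.<|im_end|><|im_start|>assistant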