am committed on
Commit
895f657
·
1 Parent(s): 651be4c
app.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
3
+ from transformers.image_utils import load_image
4
+ from transformers.image_transforms import resize
5
+ from threading import Thread
6
+ import re
7
+ import time
8
+ import torch
9
+ import spaces
10
+ import math
11
+ import os
12
# Model repo; can be overridden via the MODEL environment variable
# (default: "amrn/testmodel").
pretrained_model_name_or_path = os.environ.get("MODEL", "amrn/testmodel")

# Prefer an explicit HF_TOKEN; fall back to True so huggingface_hub uses the
# locally cached login token (required for gated/private repositories).
auth_token = os.environ.get("HF_TOKEN") or True

# Processor performs chat templating and image preprocessing.
# NOTE: token is passed here as well — previously only the model load was
# authenticated, so a gated repo would fail at the processor step.
processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path,
    use_fast=True,
    token=auth_token,
    # trust_remote_code=True
)

# Vision-language model, bfloat16 on GPU, eval mode (inference only).
model = AutoModelForImageTextToText.from_pretrained(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
    # attn_implementation="flash_attention_2",
    # trust_remote_code=True,
    token=auth_token,
).eval().to("cuda")
35
+
36
+
37
@spaces.GPU
def model_inference(input_dict, history):
    """Stream a model response for a single-image chat conversation.

    Args:
        input_dict: gr.MultimodalTextbox payload with keys "text" and "files".
        history: gr.Chatbot "messages"-style history; the first entry is
            expected to carry the image from the opening turn.

    Yields:
        str: the partial response text as it is generated (UI streaming).

    Raises:
        gr.Error: when the query text or the image is missing, or the
            history no longer contains the original image.
    """
    print(f"input_dict: {input_dict}")
    print(f"history: {history}")

    text = input_dict["text"]

    if len(history) > 0:
        # Follow-up turn: reuse the image from the conversation's first turn.
        try:
            image = history[0]['content'][0]
        except (KeyError, IndexError, TypeError):
            # History exists but holds no image — session state is unusable.
            raise gr.Error("Please refresh the page to start over.")
    else:
        # First turn: the image must come with this submission.
        try:
            image = input_dict["files"][0]
        except (KeyError, IndexError):
            raise gr.Error("Please provide an image.", duration=2)

    if len(text) == 0:
        raise gr.Error("Please input a query.", duration=2)

    # `image` is still a path string here; an empty path means no image.
    if len(image) == 0:
        raise gr.Error("Please provide an image.", duration=2)

    image = load_image(image)

    # Rebuild the chat as processor-style messages. history[0] (the
    # image-bearing turn) is skipped; the image placeholder is re-attached
    # to the first rebuilt message below.
    resulting_messages = []
    if len(history) > 0:
        for i in range(1, len(history)):
            h = history[i]
            resulting_messages.append({
                "role": h['role'],
                "content": [{"type": "text", "text": h['content']}]
            })

    # Latest user turn.
    resulting_messages.append({
        "role": "user",
        "content": [{"type": "text", "text": text}]
    })
    # Single image placeholder, always on the first message of the prompt.
    resulting_messages[0]['content'].append({"type": "image"})

    print(f"resulting_messages: {resulting_messages}")
    print(f"image0: {image} size: {image.size}")

    # Downscale so the image stays within 512*512 pixels, preserving the
    # aspect ratio (floor keeps both sides under the budget).
    width, height = image.size
    max_pixels = 512 * 512
    if height * width > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta)
        w_bar = math.floor(width / beta)
        image = image.resize((w_bar, h_bar))
        print(f"resizedimage: {image} size: {image.size}")

    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt")
    inputs = inputs.to('cuda')

    # Generate on a worker thread; the streamer feeds tokens back to the UI.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args = dict(inputs, streamer=streamer, max_new_tokens=2048)

    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    # Placeholder while the first tokens are produced.
    yield "..."
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
122
+
123
+
124
# Canned example inputs (one chest X-ray per query).
_example_query = "Find abnormalities and support devices."
examples = [
    [{"text": _example_query, "files": [f"example_images/{_name}"]}]
    for _name in ("35.jpg", "363.jpg", "376.jpg")
]

# Fixed-height chatbot in "messages" format; HTML tags are allowed so the
# model output can carry simple markup.
_chatbot = gr.Chatbot(
    type="messages",
    render_markdown=True,
    sanitize_html=False,
    allow_tags=True,
    height=640,
    min_height=640,
    max_height=640,
    resizable=False,
)

# Single-image multimodal input box.
_textbox = gr.MultimodalTextbox(
    label="Query Input",
    file_types=["image"],
    file_count="single",
    lines=1,
    max_lines=4,
)

demo = gr.ChatInterface(
    fn=model_inference,
    chatbot=_chatbot,
    type="messages",
    title="Demo",
    description="Demo.",
    examples=examples,
    textbox=_textbox,
    stop_btn=True,
    multimodal=True,
    cache_examples=False,
    fill_height=False,
    # flagging_mode="manual",
)

demo.launch(debug=False, server_name="0.0.0.0")
148
+
example_images/35.jpg ADDED
example_images/363.jpg ADDED
example_images/376.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ huggingface_hub
5
+ gradio
6
+ spaces
7
+
8
+ # accelerate
9
+ # flash-attn --no-build-isolation
10
+ # numpy
11
+ # Pillow
12
+ # requests
13
+ # pydantic==2.10.6