rafaaa2105 committed on
Commit
4829745
·
verified ·
1 Parent(s): 5f9eae1

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +341 -34
  2. chainlit.md +14 -0
app.py CHANGED
@@ -1,37 +1,344 @@
 
 
 
 
 
 
 
1
  import os
2
- import re
3
- import sys
4
- import asyncio
5
- import transformers
6
- from typing import Any, Dict, List, Union
7
- import torch
8
 
9
  import chainlit as cl
10
- from transformers import AutoTokenizer, AutoModelForCausalLM
11
-
12
- # Load the tokenizer and model
13
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7b-Instruct-v0.1")
14
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7b-Instruct-v0.1", torch_dtype=torch.float16)
15
-
16
- # Define a function to generate a response from the model
17
- async def generate_response(prompt: str) -> str:
18
- inputs = tokenizer(prompt, return_tensors="pt")
19
- inputs = {k: v.to("cuda") for k, v in inputs.items()}
20
- outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id, temperature=0.7)
21
- response = tokenizer.decode(outputs[0])
22
- return response
23
-
24
- # Define the main function that runs the chatbot
25
- async def main(request: Dict[str, Any]) -> Dict[str, Union[str, List[str]]]:
26
- # Get the user's message
27
- user_message = request["message"]
28
-
29
- # Generate a response from the model
30
- response = await generate_response(user_message)
31
-
32
- # Return the response to the user
33
- return {"message": response}
34
-
35
- # Run the chatbot
36
- if __name__ == "__main__":
37
- cl.run(main, port=8000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import auto, Enum
2
+ import json
3
+ import dataclasses
4
+ from typing import List
5
+ import aiohttp
6
+ from PIL import Image
7
+ import io
8
  import os
 
 
 
 
 
 
9
 
10
  import chainlit as cl
11
+ from chainlit.input_widget import Select, Slider
12
+
13
+ CONTROLLER_URL = os.environ.get("LLAVA_CONTROLLER_URL")
14
+
15
+
16
class SeparatorStyle(Enum):
    """Separator conventions used when flattening a conversation to a prompt."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` holds ``[role, message]`` pairs. A message may also be a
    ``(text, image, image_process_mode)`` tuple when an image is attached;
    ``get_prompt`` uses only the text part and ``get_images`` only the image.
    """

    system: str  # system prompt prepended to every serialized prompt
    roles: List[str]  # [user_role, assistant_role]
    messages: List[List[str]]
    offset: int  # messages before this index are ignored by get_images()
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None  # second separator; required by TWO/PLAIN/LLAMA_2 styles
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Serialize the history into one prompt string per ``sep_style``.

        An empty/None message renders as an open turn (``"ROLE:"`` or
        equivalent) so the model is prompted to complete it.

        Raises:
            ValueError: if ``sep_style`` is not a known style.
        """
        messages = self.messages
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate separators: sep after even (user) turns, sep2 after
            # odd (assistant) turns.
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            # MPT chat format concatenates role markers directly to the text.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        # System prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # NOTE: lstrip removes any leading characters found in `sep`
            # (character-set semantics, not prefix removal) — upstream LLaVA
            # behaviour, kept as-is.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # Role-less format: only message texts joined by alternating seps.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images attached to user turns (even indices past ``offset``).

        Args:
            return_pil: return PIL images when True, else base64-encoded PNG
                strings.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    # BUG FIX: unpack and None-check *before* importing the
                    # image libraries, so text-only tuple messages don't
                    # require PIL to be importable.
                    msg, image, image_process_mode = msg
                    if image is None:
                        continue
                    import base64
                    from io import BytesIO
                    from PIL import Image

                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter edge with the background colour
                            # so the result is square and centred.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(
                                    pil_img.mode, (width, width), background_color
                                )
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(
                                    pil_img.mode, (height, height), background_color
                                )
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(
                            f"Invalid image_process_mode: {image_process_mode}"
                        )
                    # Downscale so the longest edge is at most 800px and the
                    # shortest at most 400px, preserving aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def copy(self):
        """Return a deep-enough copy (fresh message list) of this conversation."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        """Return a JSON-serializable view; image payloads are replaced by text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [
                    [x, y[0] if type(y) is tuple else y] for x, y in self.messages
                ],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
204
+
205
+
206
+ default_conversation = Conversation(
207
+ system="A chat between a curious human and an artificial intelligence assistant. "
208
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
209
+ roles=("USER", "ASSISTANT"),
210
+ version="v1",
211
+ messages=(),
212
+ offset=0,
213
+ sep_style=SeparatorStyle.TWO,
214
+ sep=" ",
215
+ sep2="</s>",
216
+ )
217
+
218
+
219
+ headers = {"User-Agent": "LLaVA Client"}
220
+ image_process_mode = "Default"
221
+
222
+
223
async def request(conversation: Conversation, settings):
    """Stream a completion for ``conversation`` from the LLaVA controller.

    Posts the serialized prompt (plus any attached images) to the worker's
    streaming endpoint, forwards tokens to the chainlit UI as they arrive,
    and records the final text in the conversation history.

    Args:
        conversation: history whose last message is an open assistant turn.
        settings: chat-settings dict with "model", "temperature", "top_p"
            and "max_token" keys.

    Returns:
        The same ``Conversation`` with the assistant turn filled in.
    """
    prompt = conversation.get_prompt()
    pload = {
        "model": settings["model"],
        "prompt": prompt,
        "temperature": settings["temperature"],
        "top_p": settings["top_p"],
        "max_new_tokens": int(settings["max_token"]),
        # Stop at the style's end-of-turn separator.
        "stop": conversation.sep
        if conversation.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
        else conversation.sep2,
    }

    pload["images"] = conversation.get_images()

    # Last decoded output; initialised so the post-stream cleanup below is
    # safe even if the worker sends nothing.
    output = ""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            CONTROLLER_URL + "/worker_generate_stream",
            headers=headers,
            data=json.dumps(pload),
            # NOTE(review): 10s is the *total* timeout — may be too short for
            # long generations; confirm against typical worker latency.
            timeout=10,
        ) as response:
            chainlit_message = cl.Message(content="")
            # Worker streams NUL-delimited JSON objects, each carrying the
            # full text generated so far.
            async for chunk in response.content.iter_any():
                for json_str in chunk.decode().split("\0"):
                    if json_str:
                        data = json.loads(json_str)

                        if data["error_code"] == 0:
                            # Drop the echoed prompt; show a cursor while
                            # streaming.
                            output = data["text"][len(pload["prompt"]) :].strip()
                            conversation.messages[-1][-1] = output + "▌"
                            await chainlit_message.stream_token(
                                output, is_sequence=True
                            )
                        else:
                            output = (
                                data["text"] + f" (error_code: {data['error_code']})"
                            )
                            conversation.messages[-1][-1] = output
                            chainlit_message.content = output
    # BUG FIX: finalize the assistant turn without the streaming cursor, so
    # the "▌" glyph never leaks into the stored history (and hence into the
    # next serialized prompt).
    conversation.messages[-1][-1] = output
    await chainlit_message.send()
    return conversation
264
+
265
+
266
@cl.on_chat_start
async def start():
    """Initialise a chat session: render the settings panel and seed state."""
    model_select = Select(
        id="model",
        label="Model",
        values=["llava-v1.5-13b"],
        initial_index=0,
    )
    temperature_slider = Slider(
        id="temperature",
        label="Temperature",
        initial=0,
        min=0,
        max=1,
        step=0.1,
    )
    top_p_slider = Slider(
        id="top_p",
        label="Top P",
        initial=0.7,
        min=0,
        max=1,
        step=0.1,
    )
    max_token_slider = Slider(
        id="max_token",
        label="Max output tokens",
        initial=512,
        min=0,
        max=1024,
        step=64,
    )

    settings = await cl.ChatSettings(
        [model_select, temperature_slider, top_p_slider, max_token_slider]
    ).send()

    # Every session gets its own mutable copy of the template conversation.
    conversation = default_conversation.copy()

    cl.user_session.set("conversation", conversation)
    cl.user_session.set("settings", settings)
307
+
308
+
309
+ @cl.on_settings_update
310
+ async def setup_agent(settings):
311
+ cl.user_session.set("settings", settings)
312
+
313
+
314
@cl.on_message
async def main(message: cl.Message):
    """Handle one user turn: pick up any image attachment, append the turn to
    the session conversation, and stream the model's reply."""
    # First image attachment on the message, if any.
    image = None
    for element in message.elements or []:
        if "image" in element.mime and element.path is not None:
            image = Image.open(element.path)
            break

    conv = cl.user_session.get("conversation")  # type: Conversation
    settings = cl.user_session.get("settings")

    if image is None:
        text = message.content[:1536]
    else:
        # A second image starts a fresh conversation (one image per dialog).
        if len(conv.get_images(return_pil=True)) > 0:
            # reset
            conv = default_conversation.copy()
        text = message.content[:1200]
        if "<image>" not in text:
            text = "<image>\n" + text

    conv.append_message(conv.roles[0], (text, image, image_process_mode))
    # Open assistant turn, to be filled in while streaming.
    conv.append_message(conv.roles[1], None)

    conv = await request(conv, settings)

    cl.user_session.set("conversation", conv)
chainlit.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to Chainlit! 🚀🤖
2
+
3
+ Hi there, Developer! 👋 We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
4
+
5
+ ## Useful Links 🔗
6
+
7
+ - **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) 📚
8
+ - **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! 💬
9
+
10
+ We can't wait to see what you create with Chainlit! Happy coding! 💻😊
11
+
12
+ ## Welcome screen
13
+
14
+ To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.