rafaaa2105 committed
Commit 7380387 · verified
1 Parent(s): 4829745

Update app.py

Files changed (1):
  1. app.py +24 -343
app.py CHANGED
@@ -1,344 +1,25 @@
- from enum import auto, Enum
- import json
- import dataclasses
- from typing import List
- import aiohttp
- from PIL import Image
- import io
- import os
-
  import chainlit as cl
- from chainlit.input_widget import Select, Slider
-
- CONTROLLER_URL = os.environ.get("LLAVA_CONTROLLER_URL")
-
-
- class SeparatorStyle(Enum):
-     """Different separator style."""
-
-     SINGLE = auto()
-     TWO = auto()
-     MPT = auto()
-     PLAIN = auto()
-     LLAMA_2 = auto()
-
-
- @dataclasses.dataclass
- class Conversation:
-     """A class that keeps all conversation history."""
-
-     system: str
-     roles: List[str]
-     messages: List[List[str]]
-     offset: int
-     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
-     sep: str = "###"
-     sep2: str = None
-     version: str = "Unknown"
-
-     skip_next: bool = False
-
-     def get_prompt(self):
-         messages = self.messages
-         if self.sep_style == SeparatorStyle.SINGLE:
-             ret = self.system + self.sep
-             for role, message in messages:
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     ret += role + ": " + message + self.sep
-                 else:
-                     ret += role + ":"
-         elif self.sep_style == SeparatorStyle.TWO:
-             seps = [self.sep, self.sep2]
-             ret = self.system + seps[0]
-             for i, (role, message) in enumerate(messages):
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     ret += role + ": " + message + seps[i % 2]
-                 else:
-                     ret += role + ":"
-         elif self.sep_style == SeparatorStyle.MPT:
-             ret = self.system + self.sep
-             for role, message in messages:
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     ret += role + message + self.sep
-                 else:
-                     ret += role
-         elif self.sep_style == SeparatorStyle.LLAMA_2:
-             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
-             wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
-             ret = ""
-
-             for i, (role, message) in enumerate(messages):
-                 if i == 0:
-                     assert message, "first message should not be none"
-                     assert role == self.roles[0], "first message should come from user"
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     if i == 0:
-                         message = wrap_sys(self.system) + message
-                     if i % 2 == 0:
-                         message = wrap_inst(message)
-                         ret += self.sep + message
-                     else:
-                         ret += " " + message + " " + self.sep2
-                 else:
-                     ret += ""
-             ret = ret.lstrip(self.sep)
-         elif self.sep_style == SeparatorStyle.PLAIN:
-             seps = [self.sep, self.sep2]
-             ret = self.system
-             for i, (role, message) in enumerate(messages):
-                 if message:
-                     if type(message) is tuple:
-                         message, _, _ = message
-                     ret += message + seps[i % 2]
-                 else:
-                     ret += ""
-         else:
-             raise ValueError(f"Invalid style: {self.sep_style}")
-
-         return ret
-
-     def append_message(self, role, message):
-         self.messages.append([role, message])
-
-     def get_images(self, return_pil=False):
-         images = []
-         for i, (role, msg) in enumerate(self.messages[self.offset :]):
-             if i % 2 == 0:
-                 if type(msg) is tuple:
-                     import base64
-                     from io import BytesIO
-                     from PIL import Image
-
-                     msg, image, image_process_mode = msg
-                     if image == None:
-                         continue
-                     if image_process_mode == "Pad":
-
-                         def expand2square(pil_img, background_color=(122, 116, 104)):
-                             width, height = pil_img.size
-                             if width == height:
-                                 return pil_img
-                             elif width > height:
-                                 result = Image.new(
-                                     pil_img.mode, (width, width), background_color
-                                 )
-                                 result.paste(pil_img, (0, (width - height) // 2))
-                                 return result
-                             else:
-                                 result = Image.new(
-                                     pil_img.mode, (height, height), background_color
-                                 )
-                                 result.paste(pil_img, ((height - width) // 2, 0))
-                                 return result
-
-                         image = expand2square(image)
-                     elif image_process_mode in ["Default", "Crop"]:
-                         pass
-                     elif image_process_mode == "Resize":
-                         image = image.resize((336, 336))
-                     else:
-                         raise ValueError(
-                             f"Invalid image_process_mode: {image_process_mode}"
-                         )
-                     max_hw, min_hw = max(image.size), min(image.size)
-                     aspect_ratio = max_hw / min_hw
-                     max_len, min_len = 800, 400
-                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
-                     longest_edge = int(shortest_edge * aspect_ratio)
-                     W, H = image.size
-                     if longest_edge != max(image.size):
-                         if H > W:
-                             H, W = longest_edge, shortest_edge
-                         else:
-                             H, W = shortest_edge, longest_edge
-                         image = image.resize((W, H))
-                     if return_pil:
-                         images.append(image)
-                     else:
-                         buffered = BytesIO()
-                         image.save(buffered, format="PNG")
-                         img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                         images.append(img_b64_str)
-         return images
-
-     def copy(self):
-         return Conversation(
-             system=self.system,
-             roles=self.roles,
-             messages=[[x, y] for x, y in self.messages],
-             offset=self.offset,
-             sep_style=self.sep_style,
-             sep=self.sep,
-             sep2=self.sep2,
-             version=self.version,
-         )
-
-     def dict(self):
-         if len(self.get_images()) > 0:
-             return {
-                 "system": self.system,
-                 "roles": self.roles,
-                 "messages": [
-                     [x, y[0] if type(y) is tuple else y] for x, y in self.messages
-                 ],
-                 "offset": self.offset,
-                 "sep": self.sep,
-                 "sep2": self.sep2,
-             }
-         return {
-             "system": self.system,
-             "roles": self.roles,
-             "messages": self.messages,
-             "offset": self.offset,
-             "sep": self.sep,
-             "sep2": self.sep2,
-         }
-
-
- default_conversation = Conversation(
-     system="A chat between a curious human and an artificial intelligence assistant. "
-     "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-     roles=("USER", "ASSISTANT"),
-     version="v1",
-     messages=(),
-     offset=0,
-     sep_style=SeparatorStyle.TWO,
-     sep=" ",
-     sep2="</s>",
- )
-
-
- headers = {"User-Agent": "LLaVA Client"}
- image_process_mode = "Default"
-
-
- async def request(conversation: Conversation, settings):
-     pload = {
-         "model": settings["model"],
-         "prompt": conversation.get_prompt(),
-         "temperature": settings["temperature"],
-         "top_p": settings["top_p"],
-         "max_new_tokens": int(settings["max_token"]),
-         "stop": conversation.sep
-         if conversation.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
-         else conversation.sep2,
-     }
-
-     pload["images"] = conversation.get_images()
-
-     async with aiohttp.ClientSession() as session:
-         async with session.post(
-             CONTROLLER_URL + "/worker_generate_stream",
-             headers=headers,
-             data=json.dumps(pload),
-             timeout=10,
-         ) as response:
-             chainlit_message = cl.Message(content="")
-             async for chunk in response.content.iter_any():
-                 for json_str in chunk.decode().split("\0"):
-                     if json_str:
-                         data = json.loads(json_str)
-
-                         if data["error_code"] == 0:
-                             output = data["text"][len(pload["prompt"]) :].strip()
-                             conversation.messages[-1][-1] = output + "▌"
-                             await chainlit_message.stream_token(
-                                 output, is_sequence=True
-                             )
-                         else:
-                             output = (
-                                 data["text"] + f" (error_code: {data['error_code']})"
-                             )
-                             conversation.messages[-1][-1] = output
-             chainlit_message.content = output
-             await chainlit_message.send()
-     return conversation
-
-
- @cl.on_chat_start
- async def start():
-     settings = await cl.ChatSettings(
-         [
-             Select(
-                 id="model",
-                 label="Model",
-                 values=["llava-v1.5-13b"],
-                 initial_index=0,
-             ),
-             Slider(
-                 id="temperature",
-                 label="Temperature",
-                 initial=0,
-                 min=0,
-                 max=1,
-                 step=0.1,
-             ),
-             Slider(
-                 id="top_p",
-                 label="Top P",
-                 initial=0.7,
-                 min=0,
-                 max=1,
-                 step=0.1,
-             ),
-             Slider(
-                 id="max_token",
-                 label="Max output tokens",
-                 initial=512,
-                 min=0,
-                 max=1024,
-                 step=64,
-             ),
-         ]
-     ).send()
-
-     conversation = default_conversation.copy()
-
-     cl.user_session.set("conversation", conversation)
-     cl.user_session.set("settings", settings)
-
-
- @cl.on_settings_update
- async def setup_agent(settings):
-     cl.user_session.set("settings", settings)
-
-
- @cl.on_message
- async def main(message: cl.Message):
-     image = next(
-         (
-             Image.open(file.path)
-             for file in message.elements or []
-             if "image" in file.mime and file.path is not None
-         ),
-         None,
-     )
-
-     conv = cl.user_session.get("conversation")  # type: Conversation
-     settings = cl.user_session.get("settings")
-
-     if image:
-         if len(conv.get_images(return_pil=True)) > 0:
-             # reset
-             conv = default_conversation.copy()
-         text = message.content[:1200]
-         if "<image>" not in text:
-             text = "<image>\n" + text
-     else:
-         text = message.content[:1536]
-
-     conv_message = (text, image, image_process_mode)
-     conv.append_message(conv.roles[0], conv_message)
-     conv.append_message(conv.roles[1], None)
-
-     conv = await request(conv, settings)
-
-     cl.user_session.set("conversation", conv)

  import chainlit as cl
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def generate_response(prompt):
+     model_inputs = tokenizer([prompt], return_tensors="pt")
+     generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
+     response = tokenizer.batch_decode(generated_ids)[0]
+     return response
+
+ @cl.langchain_factory
+ def factory():
+     from langchain.chains import ConversationChain
+     from langchain.memory import ConversationBufferMemory
+     from langchain.llms import HuggingFacePipeline
+
+     hf_pipeline = HuggingFacePipeline(pipeline=generate_response)
+     memory = ConversationBufferMemory()
+     chain = ConversationChain(llm=hf_pipeline, memory=memory)
+     return chain
+
+ cl.Chatbot(factory).launch(share=True)