Spaces:
Sleeping
Sleeping
import re
from threading import Thread
from typing import List

import solara
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from typing_extensions import TypedDict

# Hugging Face Hub id shared by the model and tokenizer loads below.
MODEL_ID = "Qwen/Qwen2-0.5B-Instruct"

# Loaded once, when this module is imported.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Module-level text streamer bound to the tokenizer; skip_prompt=True leaves
# the prompt out of the streamed text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
class MessageDict(TypedDict):
    """Dictionary shape of one chat message."""

    # Sender of the message; this file uses "user" and "assistant".
    role: str
    # Text of the message.
    content: str
def response_generator(message):
    """Stream the model's reply to a single user message.

    The message is wrapped in the tokenizer's chat template as a lone user
    turn, generation is started on a background thread, and the text produced
    by the streamer is yielded as it arrives.

    Args:
        message: Text of the user message to answer.

    Yields:
        Successive chunks of generated text (the prompt itself is skipped).
    """
    text = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(text, return_tensors="pt")
    # Use a streamer private to this call. Sharing one streamer object between
    # calls lets overlapping generations push into, and read from, the same
    # queue, which interleaves unrelated replies.
    reply_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=reply_streamer, max_new_tokens=512)
    # generate() blocks until it finishes, so it runs on its own thread while
    # this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    yield from reply_streamer
def add_chunk_to_ai_message(chunk: str):
    """Append ``chunk`` to the content of the last message in ``messages``.

    The last entry is replaced by a new assistant message whose content is the
    previous content plus ``chunk``; a new list is assigned to
    ``messages.value`` rather than editing the existing one.
    """
    history = messages.value
    earlier = history[:-1]
    extended = {
        "role": "assistant",
        "content": history[-1]["content"] + chunk,
    }
    messages.value = [*earlier, extended]
# Reactive chat history, starting empty; each entry is a MessageDict.
messages: solara.Reactive[List[MessageDict]] = solara.reactive([])
def Page():
    """Render the chat page: themed title, message history and input box.

    Sending a message appends it to ``messages``; a background task then
    streams the model's reply into a new assistant message.
    """
    # Use the same blue accent in both the light and the dark theme.
    solara.lab.theme.themes.light.primary = "#0000ff"
    solara.lab.theme.themes.light.secondary = "#0000ff"
    solara.lab.theme.themes.dark.primary = "#0000ff"
    solara.lab.theme.themes.dark.secondary = "#0000ff"
    title = "Qwen2-0.5B-Instruct"

    # Number of user turns so far; the reply task is keyed on this value.
    user_message_count = sum(1 for m in messages.value if m["role"] == "user")

    def send(message):
        """Append the submitted text to the history as a user message."""
        messages.value = [*messages.value, {"role": "user", "content": message}]

    def response(message):
        """Add an empty assistant message and fill it in chunk by chunk."""
        messages.value = [*messages.value, {"role": "assistant", "content": ""}]
        for chunk in response_generator(message):
            add_chunk_to_ai_message(chunk)

    def generate_reply():
        """Answer the latest message, but only when it came from the user."""
        history = messages.value
        # Without the role check, a run of this task against a history that
        # ends with an assistant message would make the model answer itself.
        if history and history[-1]["role"] == "user":
            response(history[-1]["content"])

    solara.lab.use_task(generate_reply, dependencies=[user_message_count])

    if solara.lab.theme.dark_effective:
        bubble_style = "background-color:darkgrey!important;"
    else:
        bubble_style = "background-color:lightgrey!important;"

    with solara.Head():
        solara.Title(title)
    with solara.Column(align="center"):
        with solara.lab.ChatBox(style={"position": "fixed", "overflow-y": "scroll","scrollbar-width": "none", "-ms-overflow-style": "none", "top": "0", "bottom": "10rem", "width": "70%"}):
            for item in messages.value:
                is_user = item["role"] == "user"
                with solara.lab.ChatMessage(
                    user=is_user,
                    name="User" if is_user else title,
                    avatar_background_color="#33cccc" if item["role"] == "assistant" else "#ff991f",
                    border_radius="20px",
                    style=bubble_style,
                ):
                    # Strip the end-of-turn marker for display only; the dict
                    # held in the reactive state is left unmodified.
                    solara.Markdown(item["content"].replace("<|im_end|>", ""))
        solara.lab.ChatInput(send_callback=send, style={"position": "fixed", "bottom": "3rem", "width": "70%"})