Spaces: Running on Zero
| import os | |
| import time | |
| import spaces | |
| import torch | |
| import gradio as gr | |
| from threading import Thread | |
| from huggingface_hub import snapshot_download | |
| from pathlib import Path | |
| from mistral_inference.transformer import Transformer | |
| from mistral_inference.generate import generate | |
| from mistral_common.tokens.tokenizers.mistral import MistralTokenizer | |
| from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage | |
| from mistral_common.protocol.instruct.request import ChatCompletionRequest | |
# Hugging Face access token taken from the environment; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# HTML shown inside the chat window before the first message is sent.
PLACEHOLDER = """
<center>
<p>Chat with Mistral AI LLM.</p>
</center>
"""

# Page styling: a black, pill-shaped "duplicate" button and centred <h3> headings.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
"""
# Download the Ministral-8B-Instruct weights, model params and tokenizer file into a
# local folder under the user's home directory.
mistral_models_path = Path.home() / "mistral_models" / "8B-Instruct"
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistralai/Ministral-8B-Instruct-2410",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
    # Pass the token explicitly (the repo may be gated); None falls back to the
    # library's default credential lookup.
    token=HF_TOKEN,
)

# Run on the GPU when torch reports one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer (Tekken vocabulary file) and model weights, loaded from the folder above.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(
    mistral_models_path,
    device=device,
    dtype=torch.bfloat16,
)
# The reply is generated in one shot and then revealed progressively to imitate
# token streaming. Revealing several characters per step keeps long replies from
# taking minutes (50 ms per single character needs ~200 s for 4,000 characters).
_STREAM_DELAY_S = 0.05
_STREAM_CHUNK_CHARS = 8


def _history_to_messages(history):
    """Convert Gradio chat history into a list of mistral_common chat messages.

    Accepts Gradio's pair format (``[user_text, assistant_text]`` items) as well as
    the ``type="messages"`` format (dicts with ``"role"`` and ``"content"`` keys).
    Only plain-text turns are forwarded: entries whose content is ``None`` or not a
    string (e.g. file uploads) are skipped, as are roles other than user/assistant.
    """
    conversation = []
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            if not isinstance(content, str):
                continue
            if role == "user":
                conversation.append(UserMessage(content=content))
            elif role == "assistant":
                conversation.append(AssistantMessage(content=content))
            continue
        prompt, answer = item
        if isinstance(prompt, str):
            conversation.append(UserMessage(content=prompt))
        if isinstance(answer, str):
            conversation.append(AssistantMessage(content=answer))
    return conversation


@spaces.GPU
def _generate_reply(tokens, temperature, max_new_tokens):
    """Generate a completion for the already-encoded prompt ``tokens`` and decode it.

    Decorated with ``spaces.GPU`` so that, on a ZeroGPU Space, a GPU is attached
    only while this call runs (per the ZeroGPU docs, the decorator is a no-op
    elsewhere).
    NOTE(review): the default ZeroGPU time budget may be too short for the largest
    ``max_new_tokens`` settings; consider ``spaces.GPU(duration=...)``.
    """
    inner_tokenizer = tokenizer.instruct_tokenizer.tokenizer
    out_tokens, _ = generate(
        [tokens],
        model,
        # Defensive casts: slider values may arrive as floats.
        max_tokens=int(max_new_tokens),
        temperature=float(temperature),
        eos_id=inner_tokenizer.eos_id,
    )
    return inner_tokenizer.decode(out_tokens[0])


def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
):
    """Answer ``message`` given the prior ``history`` and stream the reply to Gradio.

    Args:
        message: The new user message.
        history: Previous turns as supplied by ``gr.ChatInterface`` (pair format or
            role/content dict format).
        temperature: Sampling temperature forwarded to ``generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        Successively longer prefixes of the reply; the last one is the full text.
        An empty reply is yielded once as ``""`` so the UI always receives an update.
    """
    conversation = _history_to_messages(history)
    conversation.append(UserMessage(content=message))
    request = ChatCompletionRequest(messages=conversation)
    tokens = tokenizer.encode_chat_completion(request).tokens

    # All GPU work finishes here, before any sleeping, so the simulated streaming
    # below does not hold the GPU.
    reply = _generate_reply(tokens, temperature, max_new_tokens)

    if not reply:
        yield reply
        return
    for start in range(0, len(reply), _STREAM_CHUNK_CHARS):
        time.sleep(_STREAM_DELAY_S)
        yield reply[: start + _STREAM_CHUNK_CHARS]
# Chat window used by the ChatInterface below; PLACEHOLDER is shown while it is empty.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

# Starter prompts listed under the chat box; each row only fills the message input.
EXAMPLES = [
    ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
    ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
    ["Tell me a random fun fact about the Roman Empire."],
    ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
]

with gr.Blocks(theme="citrus", css=CSS) as demo:
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Collapsed "Parameters" panel and its sliders; render=False leaves placement
    # to ChatInterface. The slider values are passed to stream_chat after
    # (message, history), in list order.
    parameters_panel = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.3, label="Temperature", render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens", render=False,
    )

    # History format is Gradio's default here (no `type=` argument is given).
    gr.ChatInterface(
        fn=stream_chat,
        title="Mistral-lab",
        chatbot=chatbot,
        fill_height=True,
        examples=EXAMPLES,
        cache_examples=False,
        additional_inputs_accordion=parameters_panel,
        additional_inputs=[temperature_slider, max_tokens_slider],
    )
def main():
    """Start the Gradio server for the chat demo."""
    demo.launch()


if __name__ == "__main__":
    main()