abhiimanyu committed (verified)
Commit 02664ae · 1 Parent(s): 0e769a1

Create app.py

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import os
+ import time  # unused in this version
+ import spaces
+ import torch  # not referenced directly; mistral_inference requires it
+ import gradio as gr
+ from threading import Thread  # unused in this version
+
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)  # huggingface_hub also picks up HF_TOKEN from the environment on its own
+
+ TITLE = "<h1><center>Mistral-lab</center></h1>"
+
+ PLACEHOLDER = """
+ <center>
+ <p>Chat with Mistral AI LLM.</p>
+ </center>
+ """
+
+ from huggingface_hub import snapshot_download
+ from pathlib import Path
+
+ mistral_models_path = Path.home().joinpath('mistral_models', '8B-Instruct')  # local directory for the weights
+ mistral_models_path.mkdir(parents=True, exist_ok=True)
+
+ snapshot_download(repo_id="mistralai/Ministral-8B-Instruct-2410", allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], local_dir=mistral_models_path)  # fetch only the files mistral_inference needs
+
+ from mistral_inference.transformer import Transformer
+ from mistral_inference.generate import generate
+
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+ from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
+ device = "cuda"  # for GPU usage, or "cpu" for CPU usage (not referenced below)
+
+ tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
+ model = Transformer.from_folder(mistral_models_path)
+
+
+ @spaces.GPU()
+ def stream_chat(
+     message: str,
+     history: list,
+     temperature: float = 0.3,
+     max_new_tokens: int = 1024,
+ ):
+     print(f'message: {message}')
+     print(f'history: {history}')
+
+     conversation = []
+     for prompt, answer in history:  # history arrives as (user, assistant) pairs from gr.ChatInterface
+         conversation.append(UserMessage(content=prompt))
+         conversation.append(AssistantMessage(content=answer))
+     conversation.append(UserMessage(content=message))
+
+     completion_request = ChatCompletionRequest(messages=conversation)
+
+     tokens = tokenizer.encode_chat_completion(completion_request).tokens
+
+     out_tokens, _ = generate(
+         [tokens],
+         model,
+         max_tokens=max_new_tokens,
+         temperature=temperature,
+         eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id)
+
+     result = tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])
+
+     return result  # returns the full completion at once; no token streaming despite the name
+
+ chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
+
+ with gr.Blocks(theme="ocean") as demo:
+     gr.HTML(TITLE)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
+     gr.ChatInterface(
+         fn=stream_chat,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(
+                 minimum=0,
+                 maximum=1,
+                 step=0.1,
+                 value=0.3,
+                 label="Temperature",
+                 render=False,
+             ),
+             gr.Slider(
+                 minimum=128,
+                 maximum=8192,
+                 step=1,
+                 value=1024,
+                 label="Max new tokens",
+                 render=False,
+             ),
+         ],
+         examples=[
+             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
+             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
+             ["Tell me a random fun fact about the Roman Empire."],
+             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
+         ],
+         cache_examples=False,
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch()
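For a quick local sanity check of the handler, something along these lines should work — a minimal sketch, assuming the dependencies (gradio, spaces, mistral-inference, mistral-common) are installed, a CUDA GPU is available, and app.py is on the import path; the `app` module name and the sample call are illustrative, not part of the commit:

    # Hypothetical smoke test, not part of this commit.
    # Importing app triggers the weight download and model load at module level;
    # @spaces.GPU is expected to act as a passthrough outside a ZeroGPU Space.
    from app import stream_chat

    reply = stream_chat(
        "Tell me a random fun fact about the Roman Empire.",  # reuses one of the Space's own examples
        history=[],          # no previous (user, assistant) turns
        temperature=0.3,
        max_new_tokens=256,
    )
    print(reply)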