Thang committed
Commit 7051c9e · 1 Parent(s): d015c2d

Complete API

Files changed (2):
  1. app.py +3 -3
  2. inference.py +45 -0
app.py CHANGED
@@ -1,5 +1,5 @@
 import streamlit as st
-# from inference import *
+from inference import *


 st.title("💬 Chatbot")
@@ -16,8 +16,8 @@ if prompt := st.chat_input():
     st.session_state.messages.append({"role": "user", "content": prompt})
     st.chat_message("user").write(prompt)

-    response = "" #client.chat.completions.create(model="gpt-3.5-turbo", messages=st.session_state.messages)
-    msg = "" #response.choices[0].message.content
+    response = generate_text(st.session_state.messages) #client.chat.completions.create(model="gpt-3.5-turbo", messages=st.session_state.messages)
+    msg = response #response.choices[0].message.content

     st.session_state.messages.append({"role": "assistant", "content": msg})
     st.chat_message("assistant").write(msg)
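For context, `st.session_state.messages` is a plain list of role/content dicts, which is exactly the shape `tokenizer.apply_chat_template` expects on the inference side. A minimal sketch of the call (illustrative contents; `generate_text` is defined in inference.py below):

    # Illustrative history as app.py accumulates it; roles alternate user/assistant.
    messages = [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi! How can I help?"},
        {"role": "user", "content": "Tell me a joke."},
    ]
    reply = generate_text(messages)  # returns the assistant's reply as a string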
inference.py ADDED
@@ -0,0 +1,45 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+import torch
+
+model_name = "mistralai/Mistral-7B-Instruct-v0.2"
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+)
+
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    trust_remote_code=True,
+    device_map="auto",
+    low_cpu_mem_usage=True,
+    # load_in_4bit = True,
+    quantization_config = bnb_config
+)
+
+
+
+def generate_text(messages):
+
+    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
+    no_token_encodeds = tokenizer.apply_chat_template(messages, tokenize=False).replace('<s>', "").replace('</s>', "")
+
+    output = model.generate(
+        encodeds,
+        max_length=200,
+        do_sample=True,
+        top_k=10,
+        num_return_sequences=1,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+
+    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
+    return output_text[len(no_token_encodeds) + 2:]
+
+    # # Remove Prompt Echo from Generated Text
+    # cleaned_output_text = output_text.replace(input_text, "")
+    # return cleaned_output_text
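A note on the new `generate_text`: as committed, `pipeline` is imported but unused, the input ids stay on CPU while `device_map="auto"` may place the model on an accelerator, and the `len(no_token_encodeds) + 2` string slice is a brittle way to drop the prompt echo, since decoding does not round-trip the template text exactly. A minimal sketch of a sturdier variant (reusing the commit's `tokenizer` and `model` globals; `generate_reply` is a hypothetical name, not part of the commit):

    def generate_reply(messages):
        # Tokenize the chat and move the ids to wherever the model was placed.
        input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
        output = model.generate(
            input_ids,
            max_new_tokens=200,  # bound the reply length rather than the total length
            do_sample=True,
            top_k=10,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Everything past the prompt ids is newly generated; slicing at the
        # token level avoids guessing character offsets in the decoded string.
        new_tokens = output[0][input_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)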