Youssefk commited on
Commit
e2226d2
·
1 Parent(s): 55e9032
Files changed (1) hide show
  1. st +63 -0
st ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from streamlit_chat import message
3
+ import requests
4
+ from transformers import AutoModelWithLMHead, AutoTokenizer
5
+
6
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
7
+ model = AutoModelWithLMHead.from_pretrained('output-small-save')
8
+
9
+ st.set_page_config(
10
+ page_title="COVID Doctor using DialoGPT",
11
+ page_icon=":robot:"
12
+ )
13
+
14
+ API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
15
+ #headers = {"Authorization": st.secrets['api_key']}
16
+
17
+ st.header("Hello - Welcome to COVID Doctor using DialoGPT")
18
+ st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")
19
+
20
+ if 'generated' not in st.session_state:
21
+ st.session_state['generated'] = []
22
+
23
+ if 'past' not in st.session_state:
24
+ st.session_state['past'] = []
25
+
26
+ def query(payload):
27
+ bot_input_ids = tokenizer.encode(payload["inputs"]["text"] + tokenizer.eos_token, return_tensors='pt')
28
+
29
+ chat_history_ids = model.generate(
30
+ bot_input_ids, max_length=100,
31
+ pad_token_id=tokenizer.eos_token_id,
32
+ no_repeat_ngram_size=3,
33
+ do_sample=True,
34
+ top_k=10,
35
+ top_p=0.7,
36
+ temperature = 0.8
37
+ )
38
+ output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
39
+ return {"generated_text": output}
40
+
41
+ def get_text():
42
+ input_text = st.text_input("You: ","I have shortness of breath and are worried, I don’t have a cough or sore throat, so they will not test me, should I do a private test?", key="input")
43
+ return input_text
44
+
45
+
46
+ user_input = get_text()
47
+
48
+ if user_input:
49
+ output = query({
50
+ "inputs": {
51
+ "past_user_inputs": st.session_state.past,
52
+ "generated_responses": st.session_state.generated,
53
+ "text": user_input,
54
+ },"parameters": {"repetition_penalty": 1.33},
55
+ })
56
+ st.session_state.past.append(user_input)
57
+ st.session_state.generated.append(output["generated_text"])
58
+
59
+ if st.session_state['generated']:
60
+
61
+ for i in range(len(st.session_state['generated'])-1, -1, -1):
62
+ message(st.session_state["generated"][i], key=str(i))
63
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')