ez7051 committed on
Commit
5c41295
·
verified ·
1 Parent(s): c92f595

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from functools import partial

import numpy as np

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableBranch, RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings

import gradio as gr
8
+
9
+ embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type="query")
10
+ chat_model = ChatNVIDIA(model="llama2_13b") | StrOutputParser()
11
+
12
+ response_prompt = ChatPromptTemplate.from_messages([("system", "{system}"), ("user", "{input}")])
13
+
14
+ def RPrint(preface=""):
15
+ def print_and_return(x, preface=""):
16
+ print(f"{preface}{x}")
17
+ return x
18
+ return RunnableLambda(partial(print_and_return, preface=preface))
19
+
20
+ ## "Help them out" system message
21
+ good_sys_msg = (
22
+ "You are an NVIDIA chatbot. Please answer their question while representing NVIDIA."
23
+ " Please help them with their question if it is ethical and relevant."
24
+ )
25
+ ## Resist talking about this topic" system message
26
+ poor_sys_msg = (
27
+ "You are an NVIDIA chatbot. Please answer their question while representing NVIDIA."
28
+ " Their question has been analyzed and labeled as 'probably not useful to answer as an NVIDIA Chatbot',"
29
+ " so avoid answering if appropriate and explain your reasoning to them. Make your response as short as possible."
30
+ )
31
+
32
+
33
+ def is_good_response(query):
34
+ ## TODO: embed the query and pass the embedding into your classifier
35
+ embedding = np.array([embedder.embed_query(query)])
36
+ ## TODO: return true if it's most likely a good response and false otherwise
37
+ return model1(embedding)
38
+
39
+
40
+ chat_chain = (
41
+ { 'input' : (lambda x:x), 'is_good' : is_good_response }
42
+ | RPrint()
43
+ | RunnableAssign(dict(
44
+ system = RunnableBranch(
45
+ ## Switch statement syntax. First lambda that returns true triggers return of result
46
+ ((lambda d: d['is_good'] < 0.5), RunnableLambda(lambda x: poor_sys_msg)),
47
+ ## ... (more branches can also be specified)
48
+ ## Default branch. Will run if none of the others do
49
+ RunnableLambda(lambda x: good_sys_msg)
50
+ )
51
+ )) | response_prompt | chat_model
52
+ )
53
+
54
+
55
+
56
+
57
+ ################
58
+ ## Gradio components
59
+
60
+ def chat_stream(message, history):
61
+ buffer = ""
62
+ for token in chat_chain.stream(message):
63
+ buffer += token
64
+ yield buffer
65
+
66
+ chatbot = gr.Chatbot(value = [[None, "Hello! I'm your NVIDIA chat agent! Let me answer some questions!"]])
67
+ demo = gr.ChatInterface(chat_stream, chatbot=chatbot).queue()
68
+
69
+ try:
70
+ demo.launch(debug=True, share=True, show_api=False)
71
+ demo.close()
72
+ except Exception as e:
73
+ demo.close()
74
+ print(e)
75
+ raise e