carlosrosas commited on
Commit
b5b6ecb
·
verified ·
1 Parent(s): d819771

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -20,7 +20,7 @@ max_new_tokens = 3000
20
  top_p = 0.95
21
  repetition_penalty = 1.2
22
 
23
- model_name = "PleIAs/Cassandre-RAG"
24
 
25
  # Initialize vLLM
26
  llm = LLM(model_name, max_model_len=8128)
@@ -56,8 +56,8 @@ def hybrid_search(text):
56
  return document, document_html
57
 
58
 
59
- class CassandreChatBot:
60
- def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
61
  self.system_prompt = system_prompt
62
 
63
  def predict(self, user_message):
@@ -110,8 +110,8 @@ def format_references(text):
110
 
111
  return ''.join(parts)
112
 
113
- # Initialize the CassandreChatBot
114
- cassandre_bot = CassandreChatBot()
115
 
116
  # CSS for styling
117
  css = """
@@ -155,20 +155,20 @@ css = """
155
 
156
  # Gradio interface
157
  def gradio_interface(user_message):
158
- response, sources = cassandre_bot.predict(user_message)
159
  return response, sources
160
 
161
  # Create Gradio app
162
  demo = gr.Blocks(css=css)
163
 
164
  with demo:
165
- gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
166
  with gr.Row():
167
  with gr.Column(scale=2):
168
  text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
169
- text_button = gr.Button("Interroger Cassandre")
170
  with gr.Column(scale=3):
171
- text_output = gr.HTML(label="La réponse de Cassandre")
172
  with gr.Row():
173
  embedding_output = gr.HTML(label="Les sources utilisées")
174
 
 
20
  top_p = 0.95
21
  repetition_penalty = 1.2
22
 
23
+ model_name = "dataesr/"
24
 
25
  # Initialize vLLM
26
  llm = LLM(model_name, max_model_len=8128)
 
56
  return document, document_html
57
 
58
 
59
+ class ESRChatBot:
60
+ def __init__(self, system_prompt="Tu es ESR, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
61
  self.system_prompt = system_prompt
62
 
63
  def predict(self, user_message):
 
110
 
111
  return ''.join(parts)
112
 
113
+ # Initialize the ESRChatBot
114
+ ESR_bot = ESRChatBot()
115
 
116
  # CSS for styling
117
  css = """
 
155
 
156
  # Gradio interface
157
  def gradio_interface(user_message):
158
+ response, sources = ESR_bot.predict(user_message)
159
  return response, sources
160
 
161
  # Create Gradio app
162
  demo = gr.Blocks(css=css)
163
 
164
  with demo:
165
+ gr.HTML("""<h1 style="text-align:center">ESR</h1>""")
166
  with gr.Row():
167
  with gr.Column(scale=2):
168
  text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
169
+ text_button = gr.Button("Interroger ESR")
170
  with gr.Column(scale=3):
171
+ text_output = gr.HTML(label="La réponse de ESR")
172
  with gr.Row():
173
  embedding_output = gr.HTML(label="Les sources utilisées")
174