carlosrosas commited on
Commit
757e2cd
·
verified ·
1 Parent(s): ef9b2a7

Upload 7 files

Browse files
Files changed (7) hide show
  1. README(4).md +13 -0
  2. app(5).py +179 -0
  3. gitattributes(7) +7 -0
  4. gitattributes(8) +36 -0
  5. pleias(1).png +0 -0
  6. requirements(2).txt +12 -0
  7. theme_builder(2).py +3 -0
README(4).md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Cassandre
3
+ emoji: 📜
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.50.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app(5).py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import re
3
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
+ from vllm import LLM, SamplingParams
5
+ import torch
6
+ import gradio as gr
7
+ import json
8
+ import os
9
+ import shutil
10
+ import requests
11
+ import lancedb
12
+ import pandas as pd
13
+
14
+ # Define the device
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Define variables
18
+ temperature = 0.7
19
+ max_new_tokens = 3000
20
+ top_p = 0.95
21
+ repetition_penalty = 1.2
22
+
23
+ model_name = "PleIAs/Cassandre-RAG"
24
+
25
+ # Initialize vLLM
26
+ llm = LLM(model_name, max_model_len=8128)
27
+
28
+ # Connect to the LanceDB database
29
+ db = lancedb.connect("content19/lancedb_data")
30
+ table = db.open_table("edunat19")
31
+
32
+ def hybrid_search(text):
33
+ results = table.search(text, query_type="hybrid").limit(5).to_pandas()
34
+
35
+ # Add a check for duplicate hashes
36
+ seen_hashes = set()
37
+
38
+ document = []
39
+ document_html = []
40
+ for _, row in results.iterrows():
41
+ hash_id = str(row['hash'])
42
+
43
+ # Skip if we've already seen this hash
44
+ if hash_id in seen_hashes:
45
+ continue
46
+
47
+ seen_hashes.add(hash_id)
48
+ title = row['section']
49
+ content = row['text']
50
+
51
+ document.append(f"**{hash_id}**\n{title}\n{content}")
52
+ document_html.append(f'<div class="source" id="{hash_id}"><p><b>{hash_id}</b> : {title}<br>{content}</div>')
53
+
54
+ document = "\n".join(document)
55
+ document_html = '<div id="source_listing">' + "".join(document_html) + "</div>"
56
+ return document, document_html
57
+
58
+
59
+ class CassandreChatBot:
60
+ def __init__(self, system_prompt="Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées."):
61
+ self.system_prompt = system_prompt
62
+
63
+ def predict(self, user_message):
64
+ fiches, fiches_html = hybrid_search(user_message)
65
+ sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_new_tokens, presence_penalty=repetition_penalty, stop=["#END#"])
66
+
67
+ detailed_prompt = f"""### Query ###\n{user_message}\n\n### Source ###\n{fiches}\n\n### Answer ###\n"""
68
+
69
+ prompts = [detailed_prompt]
70
+ outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
71
+ generated_text = outputs[0].outputs[0].text
72
+ generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + format_references(generated_text) + "</div>"
73
+ fiches_html = '<h2 style="text-align:center">Sources</h3>\n' + fiches_html
74
+ return generated_text, fiches_html
75
+
76
+ def format_references(text):
77
+ ref_start_marker = '<ref text="'
78
+ ref_end_marker = '</ref>'
79
+
80
+ parts = []
81
+ current_pos = 0
82
+ ref_number = 1
83
+
84
+ while True:
85
+ start_pos = text.find(ref_start_marker, current_pos)
86
+ if start_pos == -1:
87
+ parts.append(text[current_pos:])
88
+ break
89
+
90
+ parts.append(text[current_pos:start_pos])
91
+
92
+ end_pos = text.find('">', start_pos)
93
+ if end_pos == -1:
94
+ break
95
+
96
+ ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
97
+ ref_text_encoded = ref_text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
98
+
99
+ ref_end_pos = text.find(ref_end_marker, end_pos)
100
+ if ref_end_pos == -1:
101
+ break
102
+
103
+ ref_id = text[end_pos + 2:ref_end_pos].strip()
104
+
105
+ tooltip_html = f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}"><a href="#{ref_id}">[{ref_number}]</a></span>'
106
+ parts.append(tooltip_html)
107
+
108
+ current_pos = ref_end_pos + len(ref_end_marker)
109
+ ref_number = ref_number + 1
110
+
111
+ return ''.join(parts)
112
+
113
+ # Initialize the CassandreChatBot
114
+ cassandre_bot = CassandreChatBot()
115
+
116
+ # CSS for styling
117
+ css = """
118
+ .generation {
119
+ margin-left:2em;
120
+ margin-right:2em;
121
+ }
122
+ :target {
123
+ background-color: #CCF3DF;
124
+ }
125
+ .source {
126
+ float:left;
127
+ max-width:17%;
128
+ margin-left:2%;
129
+ }
130
+ .tooltip {
131
+ position: relative;
132
+ cursor: pointer;
133
+ font-variant-position: super;
134
+ color: #97999b;
135
+ }
136
+
137
+ .tooltip:hover::after {
138
+ content: attr(data-text);
139
+ position: absolute;
140
+ left: 0;
141
+ top: 120%;
142
+ white-space: pre-wrap;
143
+ width: 500px;
144
+ max-width: 500px;
145
+ z-index: 1;
146
+ background-color: #f9f9f9;
147
+ color: #000;
148
+ border: 1px solid #ddd;
149
+ border-radius: 5px;
150
+ padding: 5px;
151
+ display: block;
152
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
153
+ }
154
+ """
155
+
156
+ # Gradio interface
157
+ def gradio_interface(user_message):
158
+ response, sources = cassandre_bot.predict(user_message)
159
+ return response, sources
160
+
161
+ # Create Gradio app
162
+ demo = gr.Blocks(css=css)
163
+
164
+ with demo:
165
+ gr.HTML("""<h1 style="text-align:center">Cassandre</h1>""")
166
+ with gr.Row():
167
+ with gr.Column(scale=2):
168
+ text_input = gr.Textbox(label="Votre question ou votre instruction", lines=3)
169
+ text_button = gr.Button("Interroger Cassandre")
170
+ with gr.Column(scale=3):
171
+ text_output = gr.HTML(label="La réponse de Cassandre")
172
+ with gr.Row():
173
+ embedding_output = gr.HTML(label="Les sources utilisées")
174
+
175
+ text_button.click(gradio_interface, inputs=text_input, outputs=[text_output, embedding_output])
176
+
177
+ # Launch the app
178
+ if __name__ == "__main__":
179
+ demo.launch()
gitattributes(7) ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ education_corrected/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
2
+ education_corrected/e150eb41-e894-45c4-b97c-80ced9ff2123/data_level0.bin filter=lfs diff=lfs merge=lfs -text
3
+ education_corrected/a9ac8f33-9498-450a-ae99-f116efb66330/data_level0.bin filter=lfs diff=lfs merge=lfs -text
4
+ education_corrected/6af97eb5-0cfa-40b2-a4df-732ca13bd66a/data_level0.bin filter=lfs diff=lfs merge=lfs -text
5
+ content/lancedb_data/eduv1.lance/_indices/fts/55ac048af92d47c0903552a94300d4e3.store filter=lfs diff=lfs merge=lfs -text
6
+ content18/lancedb_data/edunat18.lance/_indices/fts/1655747b9eec413b8f22e287af9a4e8e.store filter=lfs diff=lfs merge=lfs -text
7
+ content19/lancedb_data/edunat19.lance/_indices/fts/28b0ccef3dd3401396a9ab134c426bc6.store filter=lfs diff=lfs merge=lfs -text
gitattributes(8) ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ education_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
pleias(1).png ADDED
requirements(2).txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ einops
4
+ accelerate
5
+ tiktoken
6
+ scipy
7
+ vllm
8
+ lancedb
9
+ sentence_transformers
10
+ gradio
11
+ pandas
12
+ tantivy
theme_builder(2).py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ import gradio as gr
2
+
3
+ gr.themes.builder()