borodache committed on
Commit 53e86fd · verified · 1 Parent(s): 9b2a7ee

Copy from the previous failed (Docker issue) Gradio space to a new space (also correcting the previous typo in the space name)

Files changed (9)
  1. .gitattributes +35 -35
  2. README.md +27 -13
  3. app.py +117 -0
  4. generator.py +63 -0
  5. rag_agent.py +125 -0
  6. requirements.txt +4 -0
  7. reranker.py +37 -0
  8. retriever.py +39 -0
  9. text_embedder_encoder.py +55 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,27 @@
- ---
- title: Hebrew Dentist
- emoji: 🐨
- colorFrom: blue
- colorTo: purple
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- short_description: A RAG Agent which works as a Hebrew speaking dentist
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Hebrew Dentist
+ emoji: 🏢
+ colorFrom: blue
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.10.0
+ app_file: app.py
+ pinned: false
+ short_description: A RAG agent: a Hebrew-speaking dentist
+ ---
+
+ Do you want to consult a dentist, in Hebrew? Consulting a dentist can be expensive... That is why I built a Hebrew RAG dentist agent you can talk to.
+
+ Warning: the agent (chatbot) can still hallucinate and make up "fake" facts, and it is not a substitute for an expert dentist. You use this chatbot at your own responsibility.
+
+ This RAG agent is based on Q&A data collected from 3 top Israeli forums. The data was collected with a scraper and saved into a SQL DB. The titles and questions were then embedded into vectors with the free 'MPA/sambert' HuggingFace encoder model (which was found to perform well on Hebrew medical jargon). The vectors were stored, a hundred at a time, in a Pinecone NoSQL vector database, with answer_id as metadata.
+ The answers were embedded with the same free encoder ('MPA/sambert') and stored in Pinecone under a different key, with the answer text as metadata (a sketch of this indexing step appears right after this README).
+ All that is left is the RAG agent itself, which is composed of a Retriever, a Reranker, and a Generator:
+ 1) The Retriever embeds the user question (with the free 'MPA/sambert' HuggingFace encoder model) and runs an ANN search with a cosine similarity metric and top_k = 50.
+ 2) The Reranker fetches the answer vectors (using the list of top_k ids, with the answers as metadata) in a second scan of the Pinecone database and re-sorts them by cosine similarity, computed with sklearn. It then keeps the top_n (= 5) answers whose similarity to the question embedding is at least 0.7.
+ 3) The Generator is a paid API - Anthropic Claude 3.5 Sonnet - a decoder that was not trained on medical jargon; however, with the right prompt and the right context the results are pretty good.
+
+ The whole work, from inception to completion, was done by me (Eli Borodach).
+
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
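
The offline indexing step described above (embedding the scraped questions and answers and upserting them to Pinecone a hundred records at a time) is not part of this Space, which only ships the query-time code. Below is a minimal sketch of what that step might look like, assuming the scraped rows are already loaded from the SQL DB into a `records` list; the index names and the `answer_id` / `answer` metadata keys mirror what retriever.py and reranker.py read back, while the record-id scheme and the field names (`forum`, `question_id`, and so on) are hypothetical.

```python
# Hedged sketch of the offline indexing step (not part of this Space).
# Assumes `records` holds rows loaded from the SQL DB built by the scraper, each with
# forum, question_id, answer_id, title, question and answer fields (hypothetical names).
import os

from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

encoder_model_name = "MPA/sambert"
model = SentenceTransformer(encoder_model_name)
pc = Pinecone(api_key=os.environ["pinecone_api_key"])

# Same index-name convention that retriever.py and reranker.py use.
question_index = pc.Index(f"hebrew-dentist-questions-{encoder_model_name.replace('/', '-')}".lower())
answer_index = pc.Index(f"hebrew-dentist-answers-{encoder_model_name.replace('/', '-')}".lower())

BATCH = 100  # vectors are stored a hundred at a time

def upsert_in_batches(index, vectors):
    for i in range(0, len(vectors), BATCH):
        index.upsert(vectors=vectors[i:i + BATCH])

records = []  # assumption: populated from the SQL DB built by the scraper
question_vectors, answer_vectors = [], []
for rec in records:
    q_emb = model.encode([rec["title"] + " " + rec["question"]])[0].tolist()
    a_emb = model.encode([rec["answer"]])[0].tolist()
    # Hypothetical id scheme "<forum>:<id>"; retriever.py only assumes that the last
    # colon-separated segment of a question id can be swapped for the answer_id.
    question_vectors.append({
        "id": f"{rec['forum']}:{rec['question_id']}",
        "values": q_emb,
        "metadata": {"answer_id": rec["answer_id"]},  # read back by retriever.py
    })
    answer_vectors.append({
        "id": f"{rec['forum']}:{rec['answer_id']}",
        "values": a_emb,
        "metadata": {"answer": rec["answer"]},  # read back by reranker.py
    })

upsert_in_batches(question_index, question_vectors)
upsert_in_batches(answer_index, answer_vectors)
```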
app.py ADDED
@@ -0,0 +1,117 @@
+ import gradio as gr
+ import time
+
+
+ from rag_agent import RAGAgent
+
+
+ rag_agent = RAGAgent()
+
+
+ class ChatBot:
+     def __init__(self, rag_agent):
+         self.message_history = []
+         self.rag_agent = rag_agent
+
+     def get_response(self, message):
+         return self.rag_agent.get_response(message)
+
+     def chat(self, message):
+         time.sleep(1)
+         bot_response = self.get_response(message)
+         self.message_history.append((message, bot_response))
+         return bot_response
+
+
+ def create_chat_interface(rag_agent=rag_agent):
+     chatbot = ChatBot(rag_agent=rag_agent)
+
+     # CSS forcing a right-to-left layout for the Hebrew UI
+     custom_css = """
+     #chatbot {
+         direction: rtl;
+         height: 400px;
+     }
+     .message {
+         font-size: 16px;
+         text-align: right;
+     }
+     .message-wrap {
+         direction: rtl !important;
+     }
+     .message-wrap > div {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .input-box {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .container {
+         direction: rtl;
+     }
+     .contain {
+         direction: rtl !important;
+     }
+     .bubble {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     textarea, input {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .user-message, .bot-message {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     """
+
+     with gr.Blocks(css=custom_css) as interface:
+         with gr.Column(elem_classes="container"):
+             gr.Markdown("רופא שיניים אלקטרוני", rtl=True)  # "Electronic dentist"
+
+             chatbot_component = gr.Chatbot(
+                 [],
+                 elem_id="chatbot",
+                 height=400,
+                 rtl=True,
+                 elem_classes="message-wrap"
+             )
+
+             with gr.Row():
+                 submit_btn = gr.Button("שלח", variant="primary")  # "Send"
+                 txt = gr.Textbox(
+                     show_label=False,
+                     placeholder="הקלד את ההודעה שלך כאן...",  # "Type your message here..."
+                     container=False,
+                     elem_classes="input-box",
+                     rtl=True
+                 )
+
+             clear_btn = gr.Button("נקה צ'אט")  # "Clear chat"
+
+             def user_message(user_message, history):
+                 return "", history + [[user_message, None]]
+
+             def bot_message(history):
+                 user_message = history[-1][0]
+                 bot_response = chatbot.chat(user_message)
+                 history[-1][1] = bot_response
+                 return history
+
+             def clear_summary():
+                 rag_agent.conversation_summary = ""
+                 rag_agent.messages = []
+
+             submit_btn.click(user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False).then(
+                 bot_message, chatbot_component, chatbot_component
+             )
+
+             clear_btn.click(clear_summary, None, chatbot_component, queue=False)
+
+     return interface
+
+
+ # Launch the interface
+ chat_interface = create_chat_interface(rag_agent=rag_agent)
+ chat_interface.launch(share=True)
generator.py ADDED
@@ -0,0 +1,63 @@
+ from retriever import Retriever
+ from reranker import Reranker
+ from anthropic import Anthropic
+ from typing import List
+ import os
+
+ from text_embedder_encoder import TextEmbedder
+
+
+ retriever = Retriever()
+ reranker = Reranker()
+
+
+ class RAGAgent:
+     def __init__(
+         self,
+         retriever=retriever,
+         reranker=reranker,
+         anthropic_api_key: str = os.environ["anthropic_api_key"],
+         model: str = "claude-3-5-sonnet-20241022",
+         max_tokens: int = 1024,
+         temperature: float = 0.0,
+     ):
+         self.retriever = retriever
+         self.reranker = reranker
+         self.client = Anthropic(api_key=anthropic_api_key)
+         self.model = model
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+         self.text_embedder = TextEmbedder()
+
+     def get_context(self, query: str) -> List[str]:
+         # Retriever.search_similar and Reranker.rerank expect a query vector, so the
+         # query is embedded first (as in rag_agent.py) before getting initial candidates
+         query_vector = self.text_embedder.encode(query)
+         retrieved_answers_ids = self.retriever.search_similar(query_vector)
+
+         # Rerank the candidates
+         context = self.reranker.rerank(query_vector, retrieved_answers_ids)
+
+         return context
+
+     def generate_prompt(self, context: List[str]) -> str:
+         context = "\n".join(context)
+         # Hebrew prompt: "You are a dentist who speaks only Hebrew. Your name is 'the first
+         # electronic Hebrew dentist'. Answer the patient's question based on the following
+         # context: {context}. Add as many details as possible and keep the phrasing correct
+         # and pleasant. Stop when you feel you have exhausted the topic. Do not make things
+         # up, and do not answer in any language other than Hebrew."
+         prompt = f"""
+         אתה רופא שיניים, דובר עברית בלבד. קוראים לך 'רופא השיניים העברי האלקטרוני הראשון'. ענה למטופל על השאלה שלו על סמך הקונטקס הבא: {context}. הוסף כמה שיותר פרטים, ודאג שהתחביר יהיה תקין ויפה. תעצור כשאתה מרגיש שמיצית את עצמך. אל תמציא דברים. ואל תענה בשפות שהן לא עברית.
+         """
+         return prompt
+
+     def get_response(self, question: str) -> str:
+         # Get relevant context
+         context = self.get_context(question)
+
+         # Generate prompt with context
+         prompt = self.generate_prompt(context)
+
+         # Get response from Claude
+         response = self.client.messages.create(
+             model=self.model,
+             max_tokens=self.max_tokens,
+             temperature=self.temperature,
+             messages=[
+                 {"role": "assistant", "content": prompt},
+                 {"role": "user", "content": f"{question}"}
+             ]
+         )
+
+         return response.content[0].text
rag_agent.py ADDED
@@ -0,0 +1,125 @@
+ from anthropic import Anthropic
+ from typing import List
+ import os
+
+
+ from retriever import Retriever
+ from reranker import Reranker
+ from text_embedder_encoder import TextEmbedder, encoder_model_name
+
+
+ retriever = Retriever()
+ reranker = Reranker()
+
+
+ class RAGAgent:
+     def __init__(
+         self,
+         retriever=retriever,
+         reranker=reranker,
+         anthropic_api_key: str = os.environ["anthropic_api_key"],
+         model_name: str = "claude-3-5-sonnet-20241022",
+         max_tokens: int = 1024,
+         temperature: float = 0.0,
+     ):
+         self.retriever = retriever
+         self.reranker = reranker
+         self.client = Anthropic(api_key=anthropic_api_key)
+         self.model_name = model_name
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+         self.text_embedder = TextEmbedder()
+         self.conversation_summary = ""
+         self.messages = []
+
+     def get_context(self, query: str) -> List[str]:
+         # Get initial candidates from retriever
+         query_vector = self.text_embedder.encode(query)
+         retrieved_answers_ids = self.retriever.search_similar(query_vector)
+         # Rerank the candidates
+         context = self.reranker.rerank(query_vector, retrieved_answers_ids)
+
+         return context
+
+     def generate_prompt(self, context: List[str], conversation_summary: str = "") -> str:
+         context = "\n".join(context)
+         # summary_context: "Summary of the conversation so far:" followed by the running summary
+         summary_context = f"\nסיכום השיחה עד כה:\n{conversation_summary}" if conversation_summary else ""
+
+         # Hebrew prompt: "You are a dentist who speaks only Hebrew. Your name is 'the first
+         # electronic Hebrew dentist'. Answer the patient's question based on the following
+         # context. Add as many details as possible and keep the phrasing correct and pleasant.
+         # Stop when you feel you have exhausted the topic. Do not make things up, and do not
+         # answer in any language other than Hebrew."
+         prompt = f"""
+         אתה רופא שיניים, דובר עברית בלבד. קוראים לך 'רופא השיניים האלקטרוני העברי הראשון'.{summary_context}
+         ענה למטופל על השאלה שלו על סמך הקונטקס הבא: {context}.
+         הוסף כמה שיותר פרטים, ודאג שהתחביר יהיה תקין ויפה.
+         תעצור כשאתה מרגיש שמיצית את עצמך. אל תמציא דברים.
+         ואל תענה בשפות שהן לא עברית.
+         """
+         return prompt
+
+     def update_summary(self, question: str, answer: str) -> str:
+         """Update the conversation summary with the new interaction"""
+         # Hebrew summarization prompt: "Summarize the conversation in Hebrew. Here is the
+         # summary so far: ... New interaction: the patient's question and the doctor's answer.
+         # Provide an updated summary that keeps the medical information from the previous
+         # summary and emphasizes the new interaction; keep it concise, up to 100 words, and
+         # drop irrelevant information from earlier summaries."
+         summary_prompt = {
+             "model": self.model_name,
+             "max_tokens": 500,
+             "temperature": 0.0,
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": f"""סכם את השיחה בעברית, הנה סיכום השיחה עד כה:
+ {self.conversation_summary if self.conversation_summary else "אין שיחה קודמת."}
+
+ אינטראקציה חדשה:
+ שאלת המטופל: {question}
+ תשובת הרופא: {answer}
+
+ אנא ספק סיכום מעודכן שכולל את המידע הרפואי מהסיכום הקודם בנוסף לדגש על האינטרקציה החדשה. הסיכום צריך להיות תמציתי עד 100 מילה.
+ ותר על מידע לא רלוונטי מהסיכומים הקודמים"""
+                 }
+             ]
+         }
+
+         try:
+             response = self.client.messages.create(**summary_prompt)
+             self.conversation_summary = response.content[0].text
+             return self.conversation_summary
+         except Exception as e:
+             print(f"Error updating summary: {e}")
+             return self.get_basic_summary()
+
+     def get_basic_summary(self) -> str:
+         """Fallback method for basic summary"""
+         # Concatenate "patient's question: ..." / "dentist's answer: ..." pairs in Hebrew
+         summary = []
+         for i in range(0, len(self.messages), 2):
+             if i + 1 < len(self.messages):
+                 summary.append(f"שאלת המטופל: {self.messages[i]['content']}")
+                 summary.append(f"תשובת הרופא שיניים: {self.messages[i + 1]['content']}\n")
+         return "\n".join(summary)
+
+     def get_response(self, question: str) -> str:
+         # Get relevant context
+         context = self.get_context(question + self.conversation_summary)
+
+         # Generate prompt with context and current conversation summary
+         prompt = self.generate_prompt(context, self.conversation_summary)
+
+         # Get response from Claude
+         response = self.client.messages.create(
+             model=self.model_name,
+             max_tokens=self.max_tokens,
+             temperature=self.temperature,
+             messages=[
+                 {"role": "assistant", "content": prompt},
+                 {"role": "user", "content": f"{question}"}
+             ]
+         )
+
+         answer = response.content[0].text
+
+         # Store messages for history
+         self.messages.extend([
+             {"role": "user", "content": question},
+             {"role": "assistant", "content": answer}
+         ])
+
+         # Update conversation summary
+         self.update_summary(question, answer)
+
+         return answer
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ anthropic==0.42.0
+ gradio==4.44.1
+ pinecone==5.4.2
+ sentence-transformers==3.2.1
reranker.py ADDED
@@ -0,0 +1,37 @@
+ from pinecone import Pinecone
+ from sklearn.metrics.pairwise import cosine_similarity
+ import os
+
+
+ from text_embedder_encoder import encoder_model_name
+
+
+ class Reranker:
+     def __init__(self,
+                  pinecone_api_key=os.environ["pinecone_api_key"],
+                  answer_index_name=f"hebrew-dentist-answers-{encoder_model_name.replace('/', '-')}".lower()):
+         self.pc = Pinecone(api_key=pinecone_api_key)
+         self.answer_index_name = answer_index_name
+
+     def rerank(self, query_vector, retrieved_answers_ids, top_n=5):
+         # Fetch the stored answer vectors and score them against the query vector
+         try:
+             index = self.pc.Index(self.answer_index_name)
+             fetch_response = index.fetch(ids=retrieved_answers_ids)
+
+             doc_embeddings = []
+             answers = []
+             for i in range(len(retrieved_answers_ids)):
+                 doc_embeddings.append(fetch_response['vectors'][retrieved_answers_ids[i]]['values'])
+                 answers.append(fetch_response['vectors'][retrieved_answers_ids[i]]['metadata']['answer'])
+
+             # Keep the top_n answers whose cosine similarity to the query is at least 0.7
+             similarity_scores = cosine_similarity([query_vector], doc_embeddings)[0]
+             similarity_scores_with_idxes = list(zip(similarity_scores, range(len(similarity_scores))))
+             similarity_scores_with_idxes.sort(reverse=True)
+             similarity_scores_with_idxes_final = similarity_scores_with_idxes[:top_n]
+             reranked_answers = [answers[idx] for score, idx in similarity_scores_with_idxes_final if score >= 0.7]
+
+             return reranked_answers
+         except Exception as e:
+             print(f"Error performing rerank: {e}")
+             return []
retriever.py ADDED
@@ -0,0 +1,39 @@
+ from pinecone import Pinecone
+ import os
+
+ from text_embedder_encoder import encoder_model_name
+
+
+ class Retriever:
+     def __init__(self,
+                  pinecone_api_key=os.environ["pinecone_api_key"],
+                  question_index_name=f"hebrew-dentist-questions-{encoder_model_name.replace('/', '-')}".lower()):
+         # Initialize Pinecone connection
+         self.pc = Pinecone(api_key=pinecone_api_key)
+         self.question_index_name = question_index_name
+
+     def search_similar(self, query_vector, top_k=50):
+         """
+         Search for similar content using vector similarity in Pinecone
+         """
+         try:
+             # Get Pinecone index
+             index = self.pc.Index(self.question_index_name)
+
+             # Execute search
+             results = index.query(
+                 vector=query_vector,
+                 top_k=top_k,
+                 include_metadata=True,
+             )
+
+             # Build the id of the matching answer record: drop the last segment of the
+             # question id and append the answer_id stored in the question's metadata
+             answers_records_ids = []
+             for match in results['matches']:
+                 answers_records_ids.append(
+                     ':'.join(match['id'].split(':')[:-1]) + ":" + str(int(match['metadata']['answer_id'])))
+
+             return answers_records_ids
+         except Exception as e:
+             print(f"Error performing retrieval: {e}")
+             return []
text_embedder_encoder.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+
+
+ encoder_model_name = 'MPA/sambert'
+
+
+ class TextEmbedder:
+     def __init__(self):
+         """
+         Initialize the Hebrew text embedder using the MPA/sambert sentence-transformers model
+         """
+         # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = SentenceTransformer(encoder_model_name)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def encode(self, text) -> list:
+         """
+         Encode Hebrew text with the MPA/sambert model.
+
+         Args:
+             text (str): Hebrew text to encode
+
+         Returns:
+             list[float]: Text embedding as a plain list of floats
+         """
+         # Get the embedding for the text and cast every value to a plain Python float
+         embeddings = [float(x) for x in self.model.encode([text])[0]]
+
+         return embeddings
+
+     # def encode_many(self, texts: List[str]) -> list:
+     #     """
+     #     Encode several Hebrew texts at once with the MPA/sambert model.
+     #
+     #     Args:
+     #         texts (List[str]): Hebrew texts to encode
+     #
+     #     Returns:
+     #         list[list[float]]: One embedding per input text
+     #     """
+     #     # Get embeddings for the texts
+     #     embeddings = self.model.encode(texts)
+     #     embeddings = [[float(x) for x in embedding] for embedding in embeddings]
+     #
+     #     return embeddings