ds-EkaCare committed on
Commit
f98d92f
·
verified ·
1 Parent(s): 8a33244

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -183
app.py CHANGED
@@ -1,204 +1,177 @@
1
- import os
2
  import gradio as gr
3
  import torch
4
  from transformers import AutoTokenizer, AutoModel
5
  import numpy as np
6
- from typing import List, Tuple
7
  import pandas as pd
8
- from huggingface_hub import login
 
9
 
10
- login(os.getenv("HF_TOKEN"))
11
- # Model configuration
12
- MODEL_NAME = "ekacare/parrotlet-e"
13
 
14
- class ParrotletEmbedder:
15
- def __init__(self, model_name: str):
16
- """Initialize the Parrotlet-E model and tokenizer."""
 
 
 
 
 
 
 
 
17
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Loading model on {self.device}...")
19
-
20
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
21
- self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
 
22
  self.model.to(self.device)
23
  self.model.eval()
24
-
 
 
 
 
 
 
 
 
 
 
 
 
25
  def mean_pooling(self, model_output, attention_mask):
26
- """Perform mean pooling on model output."""
27
  token_embeddings = model_output[0]
28
  input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
29
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
30
-
31
- def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
32
- """Encode texts into embeddings."""
33
- all_embeddings = []
34
-
 
 
 
35
  with torch.no_grad():
36
- for i in range(0, len(texts), batch_size):
37
- batch_texts = texts[i:i + batch_size]
38
-
39
- encoded_input = self.tokenizer(
40
- batch_texts,
41
- padding=True,
42
- truncation=True,
43
- max_length=512,
44
- return_tensors='pt'
45
- ).to(self.device)
46
-
47
- model_output = self.model(**encoded_input)
48
- embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
49
-
50
- # Normalize embeddings
51
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
52
- all_embeddings.append(embeddings.cpu().numpy())
53
-
54
- return np.vstack(all_embeddings)
55
-
56
- def compute_similarity(self, text1: str, text2: str) -> float:
57
- """Compute cosine similarity between two texts."""
58
- embeddings = self.encode([text1, text2])
59
- similarity = np.dot(embeddings[0], embeddings[1])
60
- return float(similarity)
61
-
62
- # Initialize model
63
- embedder = ParrotletEmbedder(MODEL_NAME)
64
-
65
- def compute_pairwise_similarity(query: str, documents: str) -> Tuple[pd.DataFrame, str]:
66
- """
67
- Compute similarity between query and multiple documents.
68
-
69
- Args:
70
- query: The query text
71
- documents: Documents separated by newlines
72
-
73
- Returns:
74
- DataFrame with documents and similarity scores, and status message
75
- """
 
 
 
 
76
  if not query.strip():
77
- return pd.DataFrame(), "⚠️ Please enter a query text."
78
-
79
- if not documents.strip():
80
- return pd.DataFrame(), "⚠️ Please enter at least one document."
81
-
82
- # Split documents by newlines and filter empty lines
83
- doc_list = [doc.strip() for doc in documents.split('\n') if doc.strip()]
84
-
85
- if len(doc_list) == 0:
86
- return pd.DataFrame(), "⚠️ No valid documents found."
87
-
88
  try:
89
- # Encode query and documents
90
- query_embedding = embedder.encode([query])
91
- doc_embeddings = embedder.encode(doc_list)
92
-
93
- # Compute similarities
94
- similarities = np.dot(doc_embeddings, query_embedding.T).flatten()
95
-
96
- # Create results dataframe
97
- results = pd.DataFrame({
98
- 'Rank': range(1, len(doc_list) + 1),
99
- 'Document': doc_list,
100
- 'Similarity Score': similarities
101
- })
102
-
103
- # Sort by similarity
104
- results = results.sort_values('Similarity Score', ascending=False).reset_index(drop=True)
105
- results['Rank'] = range(1, len(results) + 1)
106
-
107
- status = f"✅ Successfully computed similarities for {len(doc_list)} documents."
108
- return results, status
109
-
110
  except Exception as e:
111
  return pd.DataFrame(), f"❌ Error: {str(e)}"
112
 
113
- def compute_single_similarity(text1: str, text2: str) -> Tuple[str, str]:
114
- """
115
- Compute similarity between two texts.
116
-
117
- Args:
118
- text1: First text
119
- text2: Second text
120
-
121
- Returns:
122
- Similarity score and status message
123
- """
124
- if not text1.strip() or not text2.strip():
125
- return "", "⚠️ Please enter both texts."
126
-
127
- try:
128
- similarity = embedder.compute_similarity(text1, text2)
129
- score_display = f"### Similarity Score: {similarity:.4f}"
130
- status = "✅ Similarity computed successfully."
131
- return score_display, status
132
-
133
- except Exception as e:
134
- return "", f"❌ Error: {str(e)}"
135
-
136
- # Create Gradio interface
137
- with gr.Blocks(title="Parrotlet-e: Indic Medical Embedding Model", theme=gr.themes.Soft()) as demo:
138
-
139
- with gr.Tab("Query-Document Matching"):
140
- gr.Markdown("""
141
- ### 📄 Semantic Search""")
142
-
143
- with gr.Row():
144
- with gr.Column():
145
- query_input = gr.Textbox(
146
- label="term1",
147
- placeholder="",
148
- lines=1
149
- )
150
- documents_input = gr.Textbox(
151
- label="term2",
152
- placeholder="",
153
- lines=1
154
- )
155
- search_btn = gr.Button("🔍 Search", variant="primary")
156
-
157
- with gr.Column():
158
- search_output = gr.Dataframe(
159
- label="Results",
160
- headers=["Rank", "Document", "Similarity Score"],
161
- datatype=["number", "str", "number"],
162
- wrap=True
163
- )
164
- search_status = gr.Textbox(label="Status", interactive=False)
165
-
166
- search_btn.click(
167
- fn=compute_pairwise_similarity,
168
- inputs=[query_input, documents_input],
169
- outputs=[search_output, search_status]
170
- )
171
-
172
- with gr.Tab("Pairwise Similarity"):
173
- gr.Markdown("""
174
- ### 🔗 Compare Two Texts
175
- Compute semantic similarity between any two medical texts (score ranges from -1 to 1).
176
- """)
177
-
178
- with gr.Row():
179
- with gr.Column():
180
- text1_input = gr.Textbox(
181
- label="Text 1",
182
- placeholder="Enter first text...\nExample: हृदय रोग के लक्षण",
183
- lines=5
184
- )
185
- text2_input = gr.Textbox(
186
- label="Text 2",
187
- placeholder="Enter second text...\nExample: छाती में दर्द और सांस फूलना",
188
- lines=5
189
- )
190
- similarity_btn = gr.Button("⚡ Calculate Similarity", variant="primary")
191
-
192
- with gr.Column():
193
- similarity_output = gr.Markdown(label="Similarity Score")
194
- similarity_status = gr.Textbox(label="Status", interactive=False)
195
-
196
- similarity_btn.click(
197
- fn=compute_single_similarity,
198
- inputs=[text1_input, text2_input],
199
- outputs=[similarity_output, similarity_status]
200
- )
201
 
202
- # Launch the app
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  if __name__ == "__main__":
204
- demo.launch(share=True)
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModel
4
  import numpy as np
5
+ from typing import List, Dict
6
  import pandas as pd
7
+ import os
8
+ from pinecone import Pinecone
9
 
 
 
 
10
 
11
+
12
# Runtime configuration comes from environment variables. These three names
# are referenced further down in this module, so they must exist before the
# check below and before the retriever is constructed.
HF_TOKEN = os.getenv("HF_TOKEN")  # forwarded to from_pretrained(..., token=...)
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
# NOTE(review): the environment variable name for the index is an assumption;
# confirm it against the deployment's configured variables.
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")

# Fail fast with an actionable message instead of a late Pinecone auth error.
if PINECONE_API_KEY is None:
    raise ValueError("Please set PINECONE_API_KEY as an environment variable.")
14
+
15
+
16
+ # =========================
17
+ # Retriever Class
18
+ # =========================
19
class ParrotletRetriever:
    """Embed text with a Hugging Face encoder and query a Pinecone index.

    The instance holds the tokenizer/model pair, the device they run on, and
    a handle to the Pinecone index (plus an optional namespace read from the
    ``NAMESPACE`` environment variable).
    """

    def __init__(self, model_name: str, pinecone_api_key: str, index_name: str):
        """Load the model and connect to Pinecone.

        Args:
            model_name: Repository id passed to ``from_pretrained``.
            pinecone_api_key: Pinecone API key. When falsy, the
                ``PINECONE_API_KEY`` environment variable is used instead.
            index_name: Index name. Shown in the log line and used to address
                the index when ``PINECONE_HOST`` is not set.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Loading model on {self.device}...")

        # Read the token here so the class does not depend on a module global.
        hf_token = os.environ.get("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        # NOTE(review): trust_remote_code=True executes code shipped in the
        # model repository; keep it only for repositories you control.
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, token=hf_token
        )
        self.model.to(self.device)
        self.model.eval()

        self.pinecone_namespace = os.environ.get("NAMESPACE")
        # Honour the explicit argument first; the environment is the fallback.
        self.pinecone_client = Pinecone(
            api_key=pinecone_api_key or os.environ.get("PINECONE_API_KEY")
        )
        index_host = os.environ.get("PINECONE_HOST")
        if index_host:
            self.pinecone_index = self.pinecone_client.Index(host=index_host)
        else:
            # Without a host, let the client resolve the index by name.
            self.pinecone_index = self.pinecone_client.Index(name=index_name)

        print(f"Connected to Pinecone index: {index_name}")
        if self.pinecone_namespace:
            print(f"🔹 Using namespace: {self.pinecone_namespace}")

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring positions masked out with 0.

        Args:
            model_output: Sequence whose first element holds the per-token
                embeddings (assumed ``(batch, seq, dim)`` -- TODO confirm).
            attention_mask: Mask with one entry per token.

        Returns:
            One pooled vector per input row. The divisor is clamped to
            ``1e-9`` so a fully masked row does not divide by zero.
        """
        token_embeddings = model_output[0]
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode(self, texts: List[str]) -> np.ndarray:
        """Return L2-normalized embeddings for ``texts`` as a NumPy array.

        Inputs are padded to a common length and truncated at 512 tokens.
        All texts are processed in a single forward pass.
        """
        with torch.no_grad():
            encoded_input = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self.device)

            model_output = self.model(**encoded_input)
            embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.cpu().numpy()

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Query the Pinecone index with the embedding of ``query``.

        Args:
            query: Text to embed and search for.
            top_k: Number of matches to request; coerced to ``int`` because
                UI sliders may deliver a float.

        Returns:
            One dict per match with the keys ``Rank`` (1-based), ``Score``
            (string with four decimals), ``Document`` and ``ID``.
        """
        query_vector = self.encode([query])[0]

        results = self.pinecone_index.query(
            namespace=self.pinecone_namespace,
            vector=query_vector.tolist(),
            top_k=int(top_k),
            include_metadata=True,
            include_values=False,
        )

        docs = []
        for rank, match in enumerate(results["matches"], start=1):
            # A record stored without metadata must not break the whole query.
            metadata = match["metadata"] or {}
            docs.append({
                "Rank": rank,
                "Score": f"{match['score']:.4f}",
                "Document": metadata.get("text", "[No text metadata]"),
                # A match exposes id/score/metadata only, so the concept id is
                # looked up in metadata, with the vector id as the fallback.
                "ID": metadata.get("concept_id", match["id"]),
            })

        return docs
92
+
93
+
94
+
95
MODEL_NAME = "ekacare/parrotlet-e"

# Build the retriever once at import time so every request reuses the loaded
# model. Credentials are read from the environment here so this statement
# does not rely on module globals being defined elsewhere.
# NOTE(review): PINECONE_INDEX_NAME is an assumed variable name -- confirm.
retriever = ParrotletRetriever(
    MODEL_NAME,
    os.environ.get("PINECONE_API_KEY"),
    os.environ.get("PINECONE_INDEX_NAME"),
)
97
+
98
+
99
def retrieve_documents(query: str, top_k: int = 5):
    """Run a retrieval for the UI and return ``(results table, status text)``.

    Args:
        query: Text entered by the user; ``None`` or blank input is rejected
            with a warning instead of being sent to the retriever.
        top_k: Number of results to request; converted to ``int`` because a
            slider may deliver a float.

    Returns:
        A ``pandas.DataFrame`` of results (empty on warning or error) and a
        status message. Exceptions from the retriever are reported in the
        status message rather than raised, so the UI keeps working.
    """
    if query is None or not query.strip():
        return pd.DataFrame(), "⚠️ Please enter a valid query."

    try:
        results = retriever.search(query, int(top_k))
        if not results:
            return pd.DataFrame(), "⚠️ No results found."

        df = pd.DataFrame(results)
        status = f"✅ Retrieved top {len(results)} documents."
        return df, status

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return pd.DataFrame(), f"❌ Error: {str(e)}"
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ # =========================
118
+ # Gradio Interface
119
+ # =========================
120
# Example queries offered as one-click inputs in the UI.
SAMPLE_QUERIES = [
    "dm t2",
    "छाती में दर्द",
    "talenovu",
    "వాంతులు"
]

with gr.Blocks(title="Parrotlet-E Retrieval (Pinecone)", theme=gr.themes.Soft()) as demo:
    # The result count is chosen with the slider below, so the description
    # says top-K rather than a fixed number.
    gr.Markdown("""
    # 🦜 Parrotlet-E: Indic Medical Entities Retrieval
    Retrieve top-K semantically similar medical entities from Pinecone using Parrotlet-E embeddings.
    """)

    with gr.Row():
        # Left column: query box, examples, result-count slider and button.
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Enter a medical term",
                placeholder="Type your query here...",
                lines=1,
            )

            examples = gr.Examples(
                examples=SAMPLE_QUERIES,
                inputs=query_input,
                label="Example Queries"
            )

            top_k_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Top-K Retrieved"
            )

            search_btn = gr.Button("🔍 Retrieve", variant="primary")

        # Right column: results table and status line.
        with gr.Column(scale=2):
            results_output = gr.Dataframe(
                headers=["Rank", "Score", "Document", "ID"],
                datatype=["number", "str", "str", "str"],
                interactive=False,
                wrap=True
            )
            status_box = gr.Textbox(label="Status", interactive=False)

    # The handler returns (DataFrame, status), matching the two outputs.
    search_btn.click(
        fn=retrieve_documents,
        inputs=[query_input, top_k_slider],
        outputs=[results_output, status_box],
    )


# =========================
# Run App
# =========================
if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)