ds-EkaCare commited on
Commit
628bda0
·
verified ·
1 Parent(s): 09151f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import numpy as np
5
+ from typing import List, Tuple
6
+ import pandas as pd
7
+
8
+ # Model configuration
9
+ MODEL_NAME = "ekacare/parrotlet-e"
10
+
11
+ class ParrotletEmbedder:
12
+ def __init__(self, model_name: str):
13
+ """Initialize the Parrotlet-E model and tokenizer."""
14
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ print(f"Loading model on {self.device}...")
16
+
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
19
+ self.model.to(self.device)
20
+ self.model.eval()
21
+
22
+ def mean_pooling(self, model_output, attention_mask):
23
+ """Perform mean pooling on model output."""
24
+ token_embeddings = model_output[0]
25
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
26
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
27
+
28
+ def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
29
+ """Encode texts into embeddings."""
30
+ all_embeddings = []
31
+
32
+ with torch.no_grad():
33
+ for i in range(0, len(texts), batch_size):
34
+ batch_texts = texts[i:i + batch_size]
35
+
36
+ encoded_input = self.tokenizer(
37
+ batch_texts,
38
+ padding=True,
39
+ truncation=True,
40
+ max_length=512,
41
+ return_tensors='pt'
42
+ ).to(self.device)
43
+
44
+ model_output = self.model(**encoded_input)
45
+ embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
46
+
47
+ # Normalize embeddings
48
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
49
+ all_embeddings.append(embeddings.cpu().numpy())
50
+
51
+ return np.vstack(all_embeddings)
52
+
53
+ def compute_similarity(self, text1: str, text2: str) -> float:
54
+ """Compute cosine similarity between two texts."""
55
+ embeddings = self.encode([text1, text2])
56
+ similarity = np.dot(embeddings[0], embeddings[1])
57
+ return float(similarity)
58
+
59
+ # Initialize model
60
+ embedder = ParrotletEmbedder(MODEL_NAME)
61
+
62
+ def compute_pairwise_similarity(query: str, documents: str) -> Tuple[pd.DataFrame, str]:
63
+ """
64
+ Compute similarity between query and multiple documents.
65
+
66
+ Args:
67
+ query: The query text
68
+ documents: Documents separated by newlines
69
+
70
+ Returns:
71
+ DataFrame with documents and similarity scores, and status message
72
+ """
73
+ if not query.strip():
74
+ return pd.DataFrame(), "⚠️ Please enter a query text."
75
+
76
+ if not documents.strip():
77
+ return pd.DataFrame(), "⚠️ Please enter at least one document."
78
+
79
+ # Split documents by newlines and filter empty lines
80
+ doc_list = [doc.strip() for doc in documents.split('\n') if doc.strip()]
81
+
82
+ if len(doc_list) == 0:
83
+ return pd.DataFrame(), "⚠️ No valid documents found."
84
+
85
+ try:
86
+ # Encode query and documents
87
+ query_embedding = embedder.encode([query])
88
+ doc_embeddings = embedder.encode(doc_list)
89
+
90
+ # Compute similarities
91
+ similarities = np.dot(doc_embeddings, query_embedding.T).flatten()
92
+
93
+ # Create results dataframe
94
+ results = pd.DataFrame({
95
+ 'Rank': range(1, len(doc_list) + 1),
96
+ 'Document': doc_list,
97
+ 'Similarity Score': similarities
98
+ })
99
+
100
+ # Sort by similarity
101
+ results = results.sort_values('Similarity Score', ascending=False).reset_index(drop=True)
102
+ results['Rank'] = range(1, len(results) + 1)
103
+
104
+ status = f"✅ Successfully computed similarities for {len(doc_list)} documents."
105
+ return results, status
106
+
107
+ except Exception as e:
108
+ return pd.DataFrame(), f"❌ Error: {str(e)}"
109
+
110
+ def compute_single_similarity(text1: str, text2: str) -> Tuple[str, str]:
111
+ """
112
+ Compute similarity between two texts.
113
+
114
+ Args:
115
+ text1: First text
116
+ text2: Second text
117
+
118
+ Returns:
119
+ Similarity score and status message
120
+ """
121
+ if not text1.strip() or not text2.strip():
122
+ return "", "⚠️ Please enter both texts."
123
+
124
+ try:
125
+ similarity = embedder.compute_similarity(text1, text2)
126
+ score_display = f"### Similarity Score: {similarity:.4f}"
127
+ status = "✅ Similarity computed successfully."
128
+ return score_display, status
129
+
130
+ except Exception as e:
131
+ return "", f"❌ Error: {str(e)}"
132
+
133
+ # Create Gradio interface
134
+ with gr.Blocks(title="Parrotlet-e: Indic Medical Embedding Model", theme=gr.themes.Soft()) as demo:
135
+
136
+ with gr.Tab("Query-Document Matching"):
137
+ gr.Markdown("""
138
+ ### 📄 Semantic Search""")
139
+
140
+ with gr.Row():
141
+ with gr.Column():
142
+ query_input = gr.Textbox(
143
+ label="term1",
144
+ placeholder="",
145
+ lines=1
146
+ )
147
+ documents_input = gr.Textbox(
148
+ label="term2",
149
+ placeholder="",
150
+ lines=1
151
+ )
152
+ search_btn = gr.Button("🔍 Search", variant="primary")
153
+
154
+ with gr.Column():
155
+ search_output = gr.Dataframe(
156
+ label="Results",
157
+ headers=["Rank", "Document", "Similarity Score"],
158
+ datatype=["number", "str", "number"],
159
+ wrap=True
160
+ )
161
+ search_status = gr.Textbox(label="Status", interactive=False)
162
+
163
+ search_btn.click(
164
+ fn=compute_pairwise_similarity,
165
+ inputs=[query_input, documents_input],
166
+ outputs=[search_output, search_status]
167
+ )
168
+
169
+ with gr.Tab("Pairwise Similarity"):
170
+ gr.Markdown("""
171
+ ### 🔗 Compare Two Texts
172
+ Compute semantic similarity between any two medical texts (score ranges from -1 to 1).
173
+ """)
174
+
175
+ with gr.Row():
176
+ with gr.Column():
177
+ text1_input = gr.Textbox(
178
+ label="Text 1",
179
+ placeholder="Enter first text...\nExample: हृदय रोग के लक्षण",
180
+ lines=5
181
+ )
182
+ text2_input = gr.Textbox(
183
+ label="Text 2",
184
+ placeholder="Enter second text...\nExample: छाती में दर्द और सांस फूलना",
185
+ lines=5
186
+ )
187
+ similarity_btn = gr.Button("⚡ Calculate Similarity", variant="primary")
188
+
189
+ with gr.Column():
190
+ similarity_output = gr.Markdown(label="Similarity Score")
191
+ similarity_status = gr.Textbox(label="Status", interactive=False)
192
+
193
+ similarity_btn.click(
194
+ fn=compute_single_similarity,
195
+ inputs=[text1_input, text2_input],
196
+ outputs=[similarity_output, similarity_status]
197
+ )
198
+
199
+ # Launch the app
200
+ if __name__ == "__main__":
201
+ demo.launch(share=True)