Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
| from typing import List, Dict | |
| import pandas as pd | |
| import os | |
| from pinecone import Pinecone | |
| # ========================= | |
| # Retriever Class | |
| # ========================= | |
class ParrotletRetriever:
    """Embed text with a Hugging Face encoder and query a Pinecone index.

    Configuration is read from environment variables: ``HF_TOKEN`` (passed
    to ``from_pretrained``), ``PINECONE_API_KEY``, ``PINECONE_HOST`` and
    ``NAMESPACE``.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model, then open the Pinecone index.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🚀 Loading model on {self.device}...")

        # Read the token once so tokenizer and model always use the same value.
        hf_token = os.environ.get("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        self.model = AutoModel.from_pretrained(model_name, token=hf_token)
        self.model.to(self.device)
        self.model.eval()

        self.pinecone_namespace = os.environ.get("NAMESPACE")
        self.pinecone_client = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
        self.pinecone_index = self.pinecone_client.Index(
            host=os.environ.get("PINECONE_HOST")
        )

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring positions masked out as padding.

        Args:
            model_output: Model output whose first element holds the
                per-token embeddings.
            attention_mask: Mask tensor; it is expanded to the embedding
                shape and used as the weight of each token.

        Returns:
            The mask-weighted mean of the token embeddings along dimension 1.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the token count so a fully masked row cannot divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    # --------------------------
    # Text Encoder
    # --------------------------
    def encode(self, texts: List[str], max_length: int = 60) -> np.ndarray:
        """Encode texts into L2-normalized embeddings.

        Args:
            texts: Strings to embed.
            max_length: Token limit passed to the tokenizer; longer inputs
                are truncated. Defaults to 60.

        Returns:
            A NumPy array with one embedding row per input text.
        """
        with torch.no_grad():
            encoded_input = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(self.device)
            model_output = self.model(**encoded_input)
            embeddings = self.mean_pooling(
                model_output, encoded_input["attention_mask"]
            )
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.cpu().numpy()

    # --------------------------
    # Pinecone Search
    # --------------------------
    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Embed ``query`` and return the best matches from the Pinecone index.

        Args:
            query: Text to search for.
            top_k: Number of matches to request; coerced to ``int`` because
                UI widgets may deliver a float.

        Returns:
            One dict per match with the keys ``Rank`` (1-based), ``Score``
            (formatted to two decimals), ``Document`` and ``Snomed_id``.
            The last two are ``None`` when the match carries no metadata.
        """
        query_vector = self.encode([query])[0]
        results = self.pinecone_index.query(
            namespace=self.pinecone_namespace,
            vector=query_vector.tolist(),
            top_k=int(top_k),
            include_metadata=True,
            include_values=False,
        )

        docs = []
        for rank, match in enumerate(results["matches"], start=1):
            # A match may come back with no metadata; fall back to an empty
            # mapping instead of failing on ``None.get``.
            metadata = match["metadata"] or {}
            docs.append({
                "Rank": rank,
                "Score": f"{match['score']:.2f}",
                "Document": metadata.get("text"),
                "Snomed_id": metadata.get("concept_id"),
            })
        return docs
| # ========================= | |
| # Instantiate Retriever | |
| # ========================= | |
# Model identifier handed to the retriever's tokenizer/model loaders.
MODEL_NAME = "ekacare/parrotlet-e"

# Built once at module import so every request reuses the same instance.
retriever = ParrotletRetriever(model_name=MODEL_NAME)
def retrieve_documents(query: str, top_k: int = 5):
    """Run a retrieval for ``query`` and return the hits as a DataFrame.

    The handler is bound to a single Dataframe output, so every path returns
    exactly one ``pandas.DataFrame`` and never a ``(frame, message)`` tuple.

    Args:
        query: Text typed by the user. ``None`` or whitespace-only input
            yields an empty table without running a search.
        top_k: Requested number of results. It is coerced to an integer of
            at least 1; values that cannot be converted fall back to 5.

    Returns:
        A DataFrame with the columns ``Rank``, ``Score``, ``Document`` and
        ``Snomed_id``; empty when there is nothing to show.

    Raises:
        gr.Error: If the search fails, so Gradio shows the message in the UI
            instead of the failure being silently dropped.
    """
    columns = ["Rank", "Score", "Document", "Snomed_id"]

    if not query or not query.strip():
        return pd.DataFrame(columns=columns)

    # The number widget can deliver None (cleared field), a float, or a
    # non-positive value.
    try:
        k = max(1, int(top_k))
    except (TypeError, ValueError):
        k = 5

    try:
        results = retriever.search(query, k)
    except Exception as e:
        raise gr.Error(f"⚠️ Error: {str(e)}") from e

    if not results:
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(results)
| # ========================= | |
| # Gradio Interface (VERTICAL) | |
| # ========================= | |
# Example queries offered in the UI. Every entry needs its own trailing
# comma: adjacent string literals are silently concatenated into a single
# item otherwise.
SAMPLE_QUERIES = [
    "takhne me dard",
    "ghotyalu dard",
    "ghera aana",
    "vayiru vali",
    "छाती में दर्द",
    "talenovu",
    "వాంతులు",
    "ಕಾಮಲೆ",  # jaundice
    "பேசுவது சிரமம்",  # Dysphasia
    "Peshab Kartaana Jalan",  # Scalding pain on urination
    "Kurunnal",
    "sunn hua",
    "moochithinaral",
    "মাথাব্যথা",
]
# Page header rendered through gr.Markdown. The HTML is kept flush-left on
# purpose: Markdown treats lines indented by four or more spaces as a code
# block, so indenting this text to match the layout code could change how it
# renders.
_HEADER_HTML = """
<div style="text-align: center; margin-top: 10px; margin-bottom: 15px;">
<h2 style="color:#1f2937; font-size: 26px; margin-bottom: 6px;">
🦜 <b>Parrotlet-e</b> — Indic Medical Entity Embedding Model
</h2>
<p style="font-size:16px; color:#4b5563; max-width:700px; margin: 0 auto;">
A multilingual embedding model designed to represent Indian medical terminology across diverse languages and scripts —
enabling seamless medical search, interoperability, and data understanding across India’s healthcare ecosystem.
</p>
</div>
<div style="text-align: left; margin-top: 15px; font-size:15px;">
<ul style="list-style: none; padding-left: 0;">
<li>🔗 <b>Model on Hugging Face:</b>
<a href="https://huggingface.co/ekacare/parrotlet-e" target="_blank" style="color:#2563eb;">Parrotlet-e</a>
</li>
<li>📊 <b>Benchmarked on:</b>
<a href="https://huggingface.co/datasets/ekacare/Eka-IndicMTEB" target="_blank" style="color:#2563eb;">Eka-IndicMTEB</a>
</li>
<li>📰 <b>Read more on our blog:</b>
<a href="https://info.eka.care/services/introducing-parrotlet-e-and-eka-indicmteb-bridging-indias-multilingual-healthcare-gap"
target="_blank" style="color:#2563eb;">
Introducing Parrotlet-e and Eka-IndicMTEB — Bridging India’s Multilingual Healthcare Gap
</a>
</li>
</ul>
</div>
<hr style="margin-top:25px; margin-bottom:10px;">
"""

# NOTE(review): the original indentation was not available, so the widget
# nesting below is reconstructed from the section comments; confirm it
# against the deployed layout.
with gr.Blocks(title="Parrotlet-e Retrieval", theme=gr.themes.Base()) as demo:
    gr.Markdown(_HEADER_HTML)

    # ---- Input Section ----
    with gr.Group():
        query_input = gr.Textbox(
            label="Enter a medical term (not sentences in any language)",
            placeholder="Type your query here...",
            lines=1,
        )
        examples = gr.Examples(
            examples=SAMPLE_QUERIES,
            inputs=query_input,
            label="Example Queries",
            # Show every sample on one page instead of paginating.
            examples_per_page=len(SAMPLE_QUERIES),
        )
        top_k_input = gr.Number(
            label="Number of results (K)",
            value=5,
            precision=0,
            interactive=True,
        )
        search_btn = gr.Button("retrieve", variant="primary")

    # ---- Output Section ----
    with gr.Group():
        results_output = gr.Dataframe(
            # Same column names as the rows produced by the retriever, so the
            # table header does not change after the first search.
            headers=["Rank", "Score", "Document", "Snomed_id"],
            datatype=["number", "str", "str", "str"],
            interactive=False,
            wrap=True,
            label="Search Results",
        )

    # ---- Function Binding ----
    search_btn.click(
        fn=retrieve_documents,
        inputs=[query_input, top_k_input],
        outputs=[results_output],
    )
if __name__ == "__main__":
    # Bind to every network interface on port 7860.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_kwargs)