import gradio as gr import torch from transformers import AutoTokenizer, AutoModel import numpy as np from typing import List, Dict import pandas as pd import os from pinecone import Pinecone # ========================= # Retriever Class # ========================= class ParrotletRetriever: def __init__(self, model_name: str): """Initialize model and Pinecone client.""" self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"🚀 Loading model on {self.device}...") # Load tokenizer and model self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.getenv("HF_TOKEN")) self.model = AutoModel.from_pretrained(model_name, token=os.getenv("HF_TOKEN")) self.model.to(self.device) self.model.eval() self.pinecone_namespace = os.environ.get("NAMESPACE") self.pinecone_client = Pinecone(api_key=os.environ.get("PINECONE_API_KEY")) self.pinecone_index = self.pinecone_client.Index(host=os.environ.get("PINECONE_HOST")) def mean_pooling(self, model_output, attention_mask): """Mean pooling for sentence embeddings.""" token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( input_mask_expanded.sum(1), min=1e-9 ) # -------------------------- # Text Encoder # -------------------------- def encode(self, texts: List[str]) -> np.ndarray: """Encode text into normalized embeddings.""" with torch.no_grad(): encoded_input = self.tokenizer( texts, padding=True, truncation=True, max_length=60, return_tensors="pt" ).to(self.device) model_output = self.model(**encoded_input) embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"]) embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) return embeddings.cpu().numpy() # -------------------------- # Pinecone Search # -------------------------- def search(self, query: str, top_k: int = 5) -> List[Dict]: """Search Pinecone index.""" query_vector = self.encode([query])[0] results = self.pinecone_index.query( namespace=self.pinecone_namespace, vector=query_vector.tolist(), top_k=top_k, include_metadata=True, include_values=False, ) docs = [] for i, match in enumerate(results["matches"]): metadata = match["metadata"] text = metadata.get("text") docs.append({ "Rank": i + 1, "Score": f"{match['score']:.2f}", "Document": text, "Snomed_id": metadata.get("concept_id") }) return docs # ========================= # Instantiate Retriever # ========================= MODEL_NAME = "ekacare/parrotlet-e" retriever = ParrotletRetriever(MODEL_NAME) def retrieve_documents(query: str, top_k: int = 5): """Perform retrieval and return results.""" if not query.strip(): return pd.DataFrame(), "Please enter a valid query." try: results = retriever.search(query, top_k) if not results: return pd.DataFrame(), "No results found." df = pd.DataFrame(results) status = f"✅ Retrieved top {len(results)} documents." return df except Exception as e: return pd.DataFrame(), f"⚠️ Error: {str(e)}" # ========================= # Gradio Interface (VERTICAL) # ========================= SAMPLE_QUERIES = [ "takhne me dard", "ghotyalu dard", "ghera aana", "vayiru vali", "छाती में दर्द", "talenovu", "వాంతులు" "ಕಾಮಲೆ", # jaundice "பேசுவது சிரமம்", # Dysphasia "Peshab Kartaana Jalan", # Scalding pain on urination "Kurunnal", "sunn hua", "moochithinaral", "মাথাব্যথা" ] with gr.Blocks(title="Parrotlet-e Retrieval", theme=gr.themes.Base()) as demo: # gr.Markdown( # """ # # **Multilingual Embedding Retrieval powered by EkaCare’s Parrotlet-e — the Indic Medical Entity Embedding Model.** # Parrotlet-e is a multilingual embedding model built to understand and represent medical terminology across India’s diverse languages and scripts, enabling seamless search and interoperability in healthcare data. # - 🔗 **Model on Hugging Face:** [Parrotlet-e](https://huggingface.co/ekacare/parrotlet-e) # - 📊 **Benchmarked on:** [Eka-IndicMTEB](https://huggingface.co/datasets/ekacare/Eka-IndicMTEB) # - 📰 **Read more on our blog:** [Introducing Parrotlet-e and Eka-IndicMTEB — Bridging India’s Multilingual Healthcare Gap](https://info.eka.care/services/introducing-parrotlet-e-and-eka-indicmteb-bridging-indias-multilingual-healthcare-gap) # """) gr.Markdown( """
A multilingual embedding model designed to represent Indian medical terminology across diverse languages and scripts — enabling seamless medical search, interoperability, and data understanding across India’s healthcare ecosystem.