import gradio as gr
from transformers import AutoModel, AutoTokenizer
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor

MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"
KURAL_DATA_FILE = "thirukural.tsv"

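# Load the Qwen3 embedding model and tokenizer once at startup.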
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
except Exception as e:
    print(f"Error loading Transformer model: {e}")

try:
    kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE)
except FileNotFoundError:
    print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")

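# The TSV is expected to provide at least the "kural", "kural_eng" and "chapter"
# columns, which are read in find_relevant_kurals() below.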
try:
    df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
except FileNotFoundError:
    print(f"Error: The file {KURAL_DATA_FILE} was not found.")


def get_detailed_instruct(query: str) -> str:
    '''
    Wraps the user's question in the instruction format used for Qwen3-Embedding queries.
    '''
    return f'Instruct: Given a question, retrieve Thirukkural couplets that are most relevant to, or answer, the question\nQuery:{query}'
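# Example: get_detailed_instruct("How should guests be treated?") returns
# 'Instruct: Given a question, retrieve Thirukkural couplets that are most relevant to, or answer, the question\nQuery:How should guests be treated?'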

def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    '''
    Returns the embedding of the last non-padding token from Qwen3, handling both
    left-padded and right-padded batches.
    '''
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
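
# Note: kural_embeds.pt is loaded above, not built here. A minimal sketch of how it could
# be precomputed with this same model and pooling (assuming the English text in the
# "kural_eng" column is what gets embedded):
#
#     batch = tokenizer(df["kural_eng"].tolist(), max_length=128, padding=True,
#                       truncation=True, return_tensors='pt')
#     with torch.no_grad():
#         out = model(**batch)
#     torch.save(last_token_pool(out.last_hidden_state, batch['attention_mask']), KURAL_EMBEDDINGS_FILE)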

def find_relevant_kurals(question):
    """
    Finds the top 3 most relevant Kurals using cosine similarity.
    """
    batch_dict = tokenizer([get_detailed_instruct(question)], max_length=128, padding=False, truncation=True, return_tensors='pt')
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**batch_dict)
    query_embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']).detach().cpu()
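
    # Cosine similarity: L2-normalize the query embedding together with the precomputed
    # Kural embeddings, then take dot products between unit vectors.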
    all_embeddings = torch.cat((query_embedding, kural_embeddings), dim=0)
    all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
    scores = all_embeddings[:1] @ all_embeddings[1:].T

    # Pick the three highest-scoring Kurals and collect their text and metadata.
    top_indices = torch.topk(scores[0, :], 3).indices.tolist()

    results = []
    for i in top_indices:
        results.append({
            "kural_ta": df.iloc[i].get("kural", "N/A"),
            "kural_eng": df.iloc[i].get("kural_eng", "N/A"),
            "chapter": df.iloc[i].get("chapter", "N/A"),
            "similarity": scores[0, i].item()
        })
    return results

def rag_interface(question):
    """
    The main function for the Gradio interface.
    """
    if not question:
        return "Please enter a question."

    kurals = find_relevant_kurals(question)

    output = ""
    for kural in kurals:
        output += f"**Kural (Tamil):** {kural['kural_ta']}<br>"
        output += f"**Kural (English):** {kural['kural_eng']}<br>"
        output += f"**Chapter:** {kural['chapter']}<br>"
        output += f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"

    return output

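# Build the Gradio UI: one question textbox in, Markdown-formatted results out.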
iface = gr.Interface(
    fn=rag_interface,
    inputs=gr.Textbox(lines=2, placeholder="Enter your question here"),
    outputs="markdown",
    title="Kural for your question",
    description="Ask a vexing question and get 3 relevant Thirukkural couplets via embedding-similarity search.",
    flagging_mode='never'
)

if __name__ == "__main__":
    iface.launch()