File size: 3,685 Bytes
641f5e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0536153
641f5e1
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from transformers import AutoModel, AutoTokenizer
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor

# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Embedding model fetched from the Hugging Face Hub
KURAL_EMBEDDINGS_FILE = "kural_embeds.pt"  # Pre-computed couplet embeddings (torch tensor)
KURAL_DATA_FILE = "thirukural.tsv" # Tab-separated file with the Kural text (TSV, not CSV)

# --- Load Resources ---
# NOTE(review): each loader below only prints on failure and lets the script
# keep running; any later use of the missing name (tokenizer / model /
# kural_embeddings / df) will then raise NameError. Consider exiting or
# providing a fallback instead — confirm intended startup behavior.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
except Exception as e:
    print(f"Error loading Transformer model: {e}")
    # Handle model loading failure (e.g., exit or use a fallback)

try:
    # Presumably row i of these embeddings corresponds to df.iloc[i] below
    # (find_relevant_kurals indexes both with the same i) — verify on update.
    kural_embeddings = torch.load(KURAL_EMBEDDINGS_FILE)
except FileNotFoundError:
    print(f"Error: The file {KURAL_EMBEDDINGS_FILE} was not found.")

try:
    # read_csv with sep='\t' parses the tab-separated Kural table.
    df = pd.read_csv(KURAL_DATA_FILE, sep='\t')
except FileNotFoundError:
    print(f"Error: The file {KURAL_DATA_FILE} was not found.")

def get_detailed_instruct(query: str) -> str:
    """Wrap a raw user question in the instruction template the embedding model expects."""
    task = 'Given a question, retrieve relevant Thirukkural couplets that are most relevant to, or answer the question'
    return f'Instruct: {task}\nQuery:{query}'
    
def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Pool each sequence down to the hidden state of its final real token.

    With left padding every row's last column is a real token, so that column
    can be taken directly. Otherwise (right padding) each row is indexed at
    its own last non-pad position, i.e. its attention-mask row sum minus one.
    """
    batch_size = last_hidden_states.shape[0]
    every_row_ends_unpadded = attention_mask[:, -1].sum() == batch_size
    if every_row_ends_unpadded:
        # Left-padded batch: final column holds the last real token for all rows.
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    row_indices = torch.arange(batch_size, device=last_hidden_states.device)
    return last_hidden_states[row_indices, last_positions]

def find_relevant_kurals(question, top_k=3):
    """Return the ``top_k`` most relevant Kurals for ``question``.

    Embeds the instructed query with the same model used to produce
    ``kural_embeddings`` and ranks couplets by cosine similarity.

    Args:
        question: the user's free-text question.
        top_k: number of couplets to return (default 3, matching the UI).

    Returns:
        A list of ``top_k`` dicts with keys ``kural_ta``, ``kural_eng``,
        ``chapter`` and ``similarity`` (a Python float), best match first.
    """
    batch_dict = tokenizer([get_detailed_instruct(question)], max_length=128,
                           padding=False, truncation=True, return_tensors='pt')
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        outputs = model(**batch_dict)
    query_embedding = last_token_pool(outputs.last_hidden_state,
                                      batch_dict['attention_mask']).cpu()

    # Cosine similarity = dot product of L2-normalized vectors; normalize the
    # query and the corpus separately instead of concatenating them first.
    query_norm = F.normalize(query_embedding, p=2, dim=1)
    kural_norm = F.normalize(kural_embeddings, p=2, dim=1)
    scores = (query_norm @ kural_norm.T)[0]

    top_indices = torch.topk(scores, top_k).indices.tolist()

    results = []
    for i in top_indices:
        row = df.iloc[i]
        results.append({
            "kural_ta": row.get("kural", "N/A"),
            "kural_eng": row.get("kural_eng", "N/A"),
            "chapter": row.get("chapter", "N/A"),
            # .item() yields a plain float rather than a 0-dim tensor.
            "similarity": scores[i].item(),
        })
    return results

def rag_interface(question):
    """Gradio callback: retrieve Kurals for *question* and render them as Markdown."""
    if not question:
        return "Please enter a question."

    sections = []
    for kural in find_relevant_kurals(question):
        sections.append(
            f"**Kural (Tamil):** {kural['kural_ta']}<br>"
            f"**Kural (English):** {kural['kural_eng']}<br>"
            f"**Chapter:** {kural['chapter']}<br>"
            f"**Similarity:** {kural['similarity']:.2f}\n\n---\n"
        )
    return "".join(sections)

# --- Gradio Interface ---
# A single-textbox app: type a question, get back Markdown-formatted couplets.
question_input = gr.Textbox(lines=2, placeholder="Enter your question here:")

iface = gr.Interface(
    fn=rag_interface,
    inputs=question_input,
    outputs="markdown",
    title="Kural for your question",
    description="Ask a vexing question and get 3 relevant Thirukural couplets using embedding-similarity based search.",
    flagging_mode='never',
)

if __name__ == "__main__":
    iface.launch()