File size: 5,773 Bytes
8ac52c9
 
 
 
cdba4d2
8ac52c9
 
cdba4d2
8ac52c9
0a7a7a3
8ac52c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a7a7a3
8ac52c9
 
 
0a7a7a3
8ac52c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdba4d2
8ac52c9
 
 
 
 
 
 
 
 
 
 
0a7a7a3
8ac52c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import logging
import os
import re

import numpy as np
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
import spaces

# For embeddings using transformers models
@spaces.GPU
def get_embeddings(texts, model, tokenizer, batch_size=32):
    """Return mean-pooled sentence embeddings for *texts*.

    Args:
        texts: List of strings to embed. A single string is treated as a
            one-element list.
        model: Transformers encoder whose output exposes ``last_hidden_state``.
        tokenizer: Tokenizer paired with *model*.
        batch_size: Number of texts per forward pass. ``None`` or a value
            ``<= 0`` encodes everything in a single pass.

    Returns:
        NumPy array with one row per input text. For empty input it is an
        empty ``(0, hidden_size)`` array.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    if batch_size is None or batch_size <= 0:
        batch_size = len(texts)

    pooled_batches = []
    # Encode in chunks so a large upload does not have to fit in one forward pass.
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # Tensors must live on the same device as the model weights.
        encoded_input = tokenizer(
            batch, padding=True, truncation=True, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            model_output = model(**encoded_input)

        # Mean pooling: average token vectors, ignoring padding positions.
        token_embeddings = model_output.last_hidden_state
        attention_mask = encoded_input["attention_mask"]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp avoids division by zero for an all-padding row.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        pooled_batches.append((summed / counts).cpu().numpy())

    return np.concatenate(pooled_batches, axis=0)

# Calculate cosine similarity using numpy
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    A zero-length vector has no direction. In that case 0.0 is returned
    instead of the NaN (and RuntimeWarning) that a plain division would give.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return float(np.dot(a, b) / denom)

# Load models
def load_models():
    """Load the embedding and generation models from pretrained checkpoints.

    Returns:
        Tuple ``(embed_model, embed_tokenizer, generator, gen_tokenizer)``.
    """
    embed_checkpoint = "bert-base-uncased"
    gen_checkpoint = "google/flan-t5-base"

    # Encoder whose hidden states are mean-pooled into document/query vectors.
    embed_tokenizer = AutoTokenizer.from_pretrained(embed_checkpoint)
    embed_model = AutoModel.from_pretrained(embed_checkpoint)

    # Sequence-to-sequence model that writes the final answer.
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_checkpoint)
    generator = AutoModelForSeq2SeqLM.from_pretrained(gen_checkpoint)

    return embed_model, embed_tokenizer, generator, gen_tokenizer

# Process uploaded text files
def process_documents(files):
    """Read uploaded text files and split them into paragraph chunks.

    Args:
        files: Iterable of uploads. Each item is either a filesystem path
            (``str`` / ``os.PathLike``) or an object exposing its path as
            ``.name``. ``None`` is treated as no files.

    Returns:
        List of non-empty, whitespace-stripped paragraphs, in upload order.
    """
    documents = []
    for file in files or []:
        # pathlib.Path also has `.name` (the basename only), so check for
        # path-like objects first.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        # Paragraphs are separated by one or more blank lines. A "blank" line
        # may carry stray spaces/tabs, which a literal "\n\n" split misses.
        paragraphs = (p.strip() for p in re.split(r"\n\s*\n", content))
        documents.extend(p for p in paragraphs if p)
    return documents

# Create index from documents
def create_index(model, tokenizer, documents):
    """Embed *documents* and pair the vectors with the source texts.

    Returns:
        ``(embeddings, documents)``, or ``(None, None)`` when *documents* is
        empty or ``None``.
    """
    if documents:
        return get_embeddings(documents, model, tokenizer), documents
    return None, None

# Retrieve relevant documents
def retrieve(query, embeddings, documents, model, tokenizer, top_k=3):
    """Return the *top_k* documents most cosine-similar to *query*.

    Args:
        query: Question text to embed and compare.
        embeddings: Matrix of document embeddings, one row per document
            (as produced by ``create_index``), or ``None``.
        documents: Texts aligned with the rows of *embeddings*, or ``None``.
        model: Embedding model used for the query.
        tokenizer: Tokenizer paired with *model*.
        top_k: Maximum number of documents to return. Values ``<= 0`` yield ``[]``.

    Returns:
        List of document strings, most similar first. Ties keep document order.
    """
    # Check cheap exits before embedding the query. Note that top_k=0 must not
    # fall through: slicing with [-0:] would return every document.
    if embeddings is None or documents is None or len(documents) == 0 or top_k <= 0:
        return []

    query_vec = np.asarray(get_embeddings([query], model, tokenizer)[0], dtype=np.float64)
    doc_matrix = np.asarray(embeddings, dtype=np.float64)

    # Vectorized cosine similarity. Zero-norm rows score 0 rather than NaN;
    # np.argsort places NaN last, so NaN would rank as "most similar".
    dots = doc_matrix @ query_vec
    norms = np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_vec)
    similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)

    # Highest score first; a stable sort keeps earlier documents ahead on ties.
    top_indices = np.argsort(-similarities, kind="stable")[:top_k]
    return [documents[idx] for idx in top_indices]

# Generate answer
@spaces.GPU
def generate_answer(query, context, tokenizer, generator):
    """Generate an answer to *query* grounded in the retrieved *context*.

    Args:
        query: The user's question.
        context: List of retrieved passage strings. Empty means nothing is indexed.
        tokenizer: Tokenizer paired with *generator*.
        generator: Seq2seq model used for generation.

    Returns:
        The decoded answer text, or a user-facing message when *context* is empty.
    """
    if not context:
        return "No documents have been uploaded yet. Please upload some text files first."

    max_input_tokens = 1024
    prefix = "Context: "
    suffix = f"\n\nQuestion: {query}\n\nAnswer:"

    # Truncate only the context, not the whole prompt. Right-truncating the
    # full prompt would silently drop the question and the "Answer:" cue
    # whenever the retrieved passages are long.
    overhead = len(tokenizer(prefix + suffix)["input_ids"])
    budget = max_input_tokens - overhead
    combined_context = " ".join(context)
    if budget > 0:
        context_ids = tokenizer(
            combined_context,
            add_special_tokens=False,
            truncation=True,
            max_length=budget,
        )["input_ids"]
        combined_context = tokenizer.decode(context_ids, skip_special_tokens=True)
    else:
        combined_context = ""

    prompt = f"{prefix}{combined_context}{suffix}"
    # The final truncation is only a safety net for tokenization differences
    # at the splice points. Tensors go to the model's device.
    inputs = tokenizer(
        prompt, return_tensors="pt", max_length=max_input_tokens, truncation=True
    ).to(generator.device)

    with torch.no_grad():
        # Deterministic beam search: with do_sample left False, sampling knobs
        # such as temperature/top_p are ignored, so they are not passed.
        outputs = generator.generate(**inputs, max_length=256, num_beams=4)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# RAG pipeline
def rag_pipeline(query, files):
    """Answer *query* from the uploaded *files*: index, retrieve, then generate.

    Uses the module-level models (``embed_model``, ``embed_tokenizer``,
    ``generator``, ``gen_tokenizer``).

    Args:
        query: The user's question.
        files: Uploaded files as passed by the Gradio ``File`` component.

    Returns:
        A user-facing string: the generated answer, a prompt to supply missing
        input, or an error message.
    """
    if not files:
        return "Please upload some text files first."
    if not query or not query.strip():
        return "Please enter a question."

    try:
        documents = process_documents(files)
        if not documents:
            return "The uploaded files do not contain any text."

        # The index is rebuilt per request and kept local. Storing it in shared
        # module globals would let concurrent requests overwrite each other's
        # index between indexing and retrieval.
        embeddings, indexed = create_index(embed_model, embed_tokenizer, documents)
        context = retrieve(query, embeddings, indexed, embed_model, embed_tokenizer)
        return generate_answer(query, context, gen_tokenizer, generator)
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and show the
        # user a short message instead of crashing the request.
        logging.getLogger(__name__).exception("RAG pipeline failed")
        return f"An error occurred: {str(e)}"

# Initialize global variables
# Models are loaded once, when the module is imported.
(
    embed_model,
    embed_tokenizer,
    generator,
    gen_tokenizer,
) = load_models()
# Document-index state; starts out empty.
doc_embeddings = None
indexed_documents = None

# Gradio interface
with gr.Blocks(title="RAG Demo") as demo:
    # Page header.
    gr.Markdown("# 📄🔍 Retrieval-Augmented Generation (RAG) Demo")
    gr.Markdown("Upload text files and ask questions about their content.")

    with gr.Row():
        # Left column: document upload.
        with gr.Column(scale=1):
            file_output = gr.File(
                file_count="multiple",
                label="Upload Text Files (.txt)",
                file_types=[".txt"],
            )

        # Right column: question input, trigger button and answer display.
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about the uploaded documents...",
            )
            submit_btn = gr.Button("Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=10)

    # Clicking the button runs the full pipeline on the question and uploads.
    submit_btn.click(
        rag_pipeline,
        inputs=[query_input, file_output],
        outputs=answer_output,
    )

    gr.Markdown(
        """
        ## How it works
        1. Upload one or more `.txt` files
        2. Ask a question related to the content
        3. The system will:
           - Create embeddings using BERT
           - Find similar passages using vector similarity
           - Retrieve relevant context for your query
           - Generate an answer using `flan-t5-base`
        
        Built with 🤗 Hugging Face's models and Gradio
        """
    )

# Launch the app
if __name__ == '__main__':
    # Start the Gradio server for the interface defined above.
    demo.launch()