File size: 5,873 Bytes
554fa87
3f28fb7
6a1ebd8
 
 
 
 
3f28fb7
3bebfeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a1ebd8
79d5bf3
3bebfeb
 
 
 
79d5bf3
3bebfeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79d5bf3
 
3bebfeb
 
 
3f28fb7
79d5bf3
3bebfeb
 
 
 
 
79d5bf3
 
3bebfeb
 
6a1ebd8
3bebfeb
a2bdf76
 
3bebfeb
 
79d5bf3
3f28fb7
3bebfeb
79d5bf3
3f28fb7
3bebfeb
 
3f28fb7
3bebfeb
 
 
 
 
a2bdf76
3bebfeb
a2bdf76
3bebfeb
a2bdf76
3bebfeb
 
 
 
 
 
79d5bf3
3bebfeb
 
 
 
 
 
a2bdf76
3bebfeb
b28dcf5
a2bdf76
a06108e
a2bdf76
3bebfeb
 
 
 
 
3f28fb7
79d5bf3
3bebfeb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
import requests
import os
from PyPDF2 import PdfReader
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Constants
CHUNK_SIZE = 300  # words per chunk produced by split_text
MODEL_NAME = "all-MiniLM-L6-v2"  # SentenceTransformer checkpoint used for retrieval
# API keys come from the environment; each is None when the variable is unset.
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")

# Load the sentence embedding model once, at import time, so every request reuses it.
model = SentenceTransformer(MODEL_NAME)

# Global state: chunks of the most recently processed PDF and their embeddings.
doc_chunks = []
doc_embeddings = []

# --- Text Extraction from PDF ---
def extract_pdf_text(file_obj):
    """Extract the text of every page of a PDF and join it with newlines.

    Args:
        file_obj: Anything ``PdfReader`` accepts (a path or a binary stream);
            in this app it is the value handed over by the ``gr.File`` upload.

    Returns:
        The non-empty page texts joined by newlines, or ``""`` when no page
        yields any text (for example a scanned PDF without a text layer).
    """
    reader = PdfReader(file_obj)
    page_texts = []
    for page in reader.pages:
        # extract_text() performs the full extraction on every call, so call it
        # once per page instead of once for the filter and again for the value.
        text = page.extract_text()
        if text:  # skip pages that yield None or "" (no extractable text)
            page_texts.append(text)
    return "\n".join(page_texts)

# --- Chunk Text ---
def split_text(text, size=CHUNK_SIZE):
    """Break *text* into consecutive chunks of at most *size* words.

    Words are whitespace-delimited tokens. Each chunk re-joins its words with
    single spaces, and only the final chunk may hold fewer than *size* words.
    Empty or all-whitespace text yields an empty list.
    """
    tokens = text.split()
    starts = range(0, len(tokens), size)
    return [" ".join(tokens[start:start + size]) for start in starts]

# --- File Upload Handling ---
def handle_file_upload(file):
    """Extract, chunk, and embed an uploaded PDF, caching the result globally.

    The module-level ``doc_chunks`` / ``doc_embeddings`` are only replaced once
    the whole pipeline has succeeded. Both are then assigned together, so they
    always describe the same document. If extraction or encoding raises, any
    previously loaded document is left intact.

    Args:
        file: The value of the ``gr.File`` component, or a falsy value when no
            file is selected (e.g. the upload was cleared).

    Returns:
        A ``(status_message, doc_info_update)`` pair. The first item feeds the
        Status textbox; the second is a ``gr.update`` for the Chunks Info
        textbox.
    """
    global doc_chunks, doc_embeddings
    if not file:
        return "⚠️ Please upload a file.", gr.update(visible=False)

    try:
        chunks = split_text(extract_pdf_text(file))
        # Only run the model when there is something to embed.
        embeddings = model.encode(chunks) if chunks else []
    except Exception as e:
        # UI boundary: report the failure instead of crashing the handler, and
        # keep the previous document's chunks/embeddings consistent.
        return f"❌ Failed to process file: {e}", gr.update(visible=False)

    # Commit chunks and embeddings together so retrieval indices always line up.
    doc_chunks, doc_embeddings = chunks, embeddings
    if not chunks:
        # Text-less PDF: the cache is now empty, so queries will ask for an upload.
        return "⚠️ No extractable text found in this PDF.", gr.update(visible=False)
    return f"✅ Processed {len(doc_chunks)} chunks.", gr.update(visible=True, value=f"{len(doc_chunks)} chunks ready.")

# --- Semantic Retrieval ---
def get_top_chunks(query, k=3):
    """Return the *k* stored chunks most similar to *query*, best match first.

    The query is embedded with the same model as the document and scored
    against ``doc_embeddings`` by cosine similarity. The winning chunks are
    joined with blank lines so they can be used directly as prompt context.
    Expects a document to be loaded; ``respond_to_query`` checks ``doc_chunks``
    before calling this.
    """
    scores = cosine_similarity(model.encode([query]), doc_embeddings)[0]
    ranked = np.argsort(scores)[::-1]  # indices ordered from highest to lowest score
    best = ranked[:k]
    return "\n\n".join(doc_chunks[idx] for idx in best)

# --- Call LLM via Together API ---
def call_together_ai(context, question, timeout=60):
    """Ask the Mixtral model on Together's chat-completions API to answer *question*.

    Args:
        context: Retrieved text (PDF chunks or web snippets) placed in the prompt.
        question: The user's question.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The assistant message text from the first completion choice.

    Raises:
        RuntimeError: If ``TOGETHER_API_KEY`` is unset, or if the response
            lacks the expected ``choices[0].message.content`` field.
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
    """
    if not TOGETHER_API_KEY:
        # Without this check the request goes out as "Bearer None".
        raise RuntimeError("TOGETHER_API_KEY environment variable is not set.")
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant answering from the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
        ],
        "temperature": 0.7,
        "max_tokens": 512
    }
    res = requests.post(url, headers=headers, json=payload, timeout=timeout)
    # Surface HTTP errors (bad key, rate limit, ...) directly instead of letting
    # them show up later as a confusing KeyError on the missing "choices" field.
    res.raise_for_status()
    data = res.json()
    try:
        return data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as err:
        raise RuntimeError(f"Unexpected Together API response: {data!r}") from err

# --- Serper Web Search ---
def fetch_web_snippets(query, num_results=3, timeout=15):
    """Search the web through the Serper API and format the top organic hits.

    Args:
        query: The search string sent to Serper.
        num_results: How many organic results to include (default 3).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        One entry per result: a markdown ``[title](link)`` line followed by its
        snippet. Entries are joined with newlines; the result is ``""`` when
        Serper returns no organic results.

    Raises:
        RuntimeError: If ``SERPER_API_KEY`` is not set.
        requests.RequestException: On network failure, timeout, or a non-2xx
            HTTP status.
    """
    if not SERPER_API_KEY:
        raise RuntimeError("SERPER_API_KEY environment variable is not set.")
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": SERPER_API_KEY}
    res = requests.post(url, json={"q": query}, headers=headers, timeout=timeout)
    # Fail loudly on HTTP errors (bad key, exhausted quota) rather than silently
    # treating the error body as "no results" and prompting the LLM with nothing.
    res.raise_for_status()
    organic = res.json().get("organic", [])[:num_results]
    # Read fields defensively: an organic result without e.g. a snippet should
    # not abort the whole search with a KeyError.
    return "\n".join(
        f"πŸ”Ή [{r.get('title', '')}]({r.get('link', '')})\n{r.get('snippet', '')}"
        for r in organic
    )

# --- Main Chat Logic ---
def respond_to_query(question, source, history):
    """Answer *question* from the selected knowledge source and log it in *history*.

    *history* is the Chatbot's list of ``[user, bot]`` pairs. A new pair is
    appended in place, and its bot slot is filled with the answer or with a
    warning/error message. The return value is ``(history, "")`` so the query
    box is cleared. A blank question returns *history* unchanged.
    """
    if not question.strip():
        return history, ""

    turn = [question, None]
    history.append(turn)

    try:
        if source == "🌐 Web Search":
            context, source_note = fetch_web_snippets(question), "🌐 Web Search"
        elif source == "πŸ“„ Uploaded File":
            if not doc_chunks:
                # Retrieval needs a processed document; tell the user instead.
                turn[1] = "⚠️ Please upload a PDF document first."
                return history, ""
            context, source_note = get_top_chunks(question), "πŸ“„ Uploaded Document"
        else:
            turn[1] = "❌ Invalid knowledge source selected."
            return history, ""

        answer = call_together_ai(context, question)
        turn[1] = f"**{source_note}**\n\n{answer}"
    except Exception as e:
        # UI boundary: show any retrieval/LLM failure in the chat transcript.
        turn[1] = f"❌ Error: {e}"
    return history, ""

# --- Clear Chat ---
def clear_chat():
    """Return a fresh, empty history, which resets the Chatbot component."""
    return []

# --- UI Design ---
# Custom CSS: cap the app width at 1100px, center it, and center all headings.
css = """
.gradio-container { max-width: 1100px !important; margin: auto; }
h1, h2, h3 { text-align: center; }
"""

# Build the interface; the layout follows the nesting of the `with` blocks below.
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="πŸ” AI RAG Assistant") as demo:

    gr.HTML("<h1>πŸ€– AI Chat with RAG Capabilities</h1><h3>Ask questions from PDFs or real-time web search</h3>")

    with gr.Row():
        # Left column: knowledge-source selector, PDF upload, and processing status.
        with gr.Column(scale=1):
            source = gr.Radio(["🌐 Web Search", "πŸ“„ Uploaded File"], label="Knowledge Source", value="🌐 Web Search")
            file = gr.File(label="Upload PDF", file_types=[".pdf"])
            status = gr.Textbox(label="Status", interactive=False)
            # Hidden at start; handle_file_upload makes it visible after a successful upload.
            doc_info = gr.Textbox(label="Chunks Info", visible=False, interactive=False)

        # Right column (scale=2, i.e. twice the left column): transcript, input box, buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat", height=500)
            query = gr.Textbox(placeholder="Type your question here...", lines=2)
            with gr.Row():
                send = gr.Button("Send")
                clear = gr.Button("Clear")

    with gr.Accordion("ℹ️ Info", open=False):
        gr.Markdown("- Web Search fetches latest online results\n- PDF mode retrieves answers from your document")

    gr.HTML("<div style='text-align:center; font-size:0.9em; color:gray;'>ipradeepsengarr</div>")

    # Bind events
    # Re-process the document whenever the file component's value changes
    # (including when it is cleared, which handle_file_upload answers with a warning).
    file.change(handle_file_upload, inputs=file, outputs=[status, doc_info])
    # Pressing Enter in the query box and clicking Send run the same handler; it
    # returns (updated history, "") so the transcript refreshes and the box clears.
    query.submit(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    send.click(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    clear.click(clear_chat, outputs=[chatbot])

if __name__ == "__main__":
    # Listen on all interfaces, port 7860.
    # NOTE(review): share=True also asks Gradio for a public share link, exposing
    # the app (and the API keys' quota behind it) to anyone holding the link.
    # Confirm this is intended.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)