Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import os | |
| from PyPDF2 import PdfReader | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
# --- Configuration constants ---
CHUNK_SIZE = 300  # words per document chunk (see split_text)
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers embedding model
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")  # LLM credential; None if env var unset
SERPER_API_KEY = os.getenv("SERPER_API_KEY")  # web-search credential; None if env var unset

# Load the sentence-embedding model once at import time
# (downloads weights on first run — NOTE(review): this blocks startup).
model = SentenceTransformer(MODEL_NAME)

# Global state for the most recently uploaded PDF: its text chunks and their
# embeddings. Written by handle_file_upload, read by get_top_chunks.
doc_chunks, doc_embeddings = [], []
# --- Text Extraction from PDF ---
def extract_pdf_text(file_obj):
    """Extract and join the text of every page in a PDF.

    Args:
        file_obj: A path or binary file-like object accepted by ``PyPDF2.PdfReader``.

    Returns:
        str: Page texts joined with newlines; pages yielding no text are skipped.
    """
    reader = PdfReader(file_obj)
    # Call extract_text() exactly once per page — the original called it twice
    # (once in the filter, once in the join), doubling the parse work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)
# --- Chunk Text ---
def split_text(text, size=CHUNK_SIZE):
    """Break *text* into chunks of at most *size* whitespace-separated words."""
    words = text.split()
    chunks = []
    for start in range(0, len(words), size):
        chunks.append(" ".join(words[start:start + size]))
    return chunks
# --- File Upload Handling ---
def handle_file_upload(file):
    """Process an uploaded PDF: chunk its text, embed the chunks, cache both.

    Returns a (status message, gr.update for the chunk-info box) pair.
    """
    global doc_chunks, doc_embeddings
    if not file:
        return "β οΈ Please upload a file.", gr.update(visible=False)
    try:
        doc_chunks = split_text(extract_pdf_text(file))
        doc_embeddings = model.encode(doc_chunks)
    except Exception as e:
        return f"β Failed to process file: {e}", gr.update(visible=False)
    n = len(doc_chunks)
    return f"β Processed {n} chunks.", gr.update(visible=True, value=f"{n} chunks ready.")
# --- Semantic Retrieval ---
def get_top_chunks(query, k=3):
    """Return the *k* cached chunks most similar to *query*, blank-line separated."""
    query_vec = model.encode([query])
    scores = cosine_similarity(query_vec, doc_embeddings)[0]
    # Highest-similarity indices first.
    best = np.argsort(scores)[::-1][:k]
    return "\n\n".join(doc_chunks[i] for i in best)
# --- Call LLM via Together API ---
def call_together_ai(context, question):
    """Answer *question* from *context* using Mixtral via the Together chat API.

    Args:
        context: Retrieved text to ground the answer in.
        question: The user's question.

    Returns:
        str: The assistant message content from the first completion choice.

    Raises:
        requests.HTTPError: if the API responds with a non-2xx status.
        requests.Timeout: if the request exceeds the timeout.
    """
    url = "https://api.together.xyz/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {TOGETHER_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant answering from the given context."},
            {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"}
        ],
        "temperature": 0.7,
        "max_tokens": 512
    }
    # timeout: never hang the UI thread forever on a stalled connection.
    res = requests.post(url, headers=headers, json=payload, timeout=60)
    # Surface HTTP errors explicitly instead of a confusing KeyError when the
    # error body has no "choices" key; the caller already catches Exception.
    res.raise_for_status()
    return res.json()["choices"][0]["message"]["content"]
# --- Serper Web Search ---
def fetch_web_snippets(query):
    """Search the web via the Serper API and format the top 3 organic results.

    Returns a markdown string of up to three "title (link) + snippet" entries;
    empty string when there are no organic results.

    Raises:
        requests.HTTPError: if the API responds with a non-2xx status.
        requests.Timeout: if the request exceeds the timeout.
    """
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": SERPER_API_KEY}
    # timeout + status check: fail fast and loudly instead of hanging or
    # crashing later on an unexpected error payload.
    res = requests.post(url, json={"q": query}, headers=headers, timeout=30)
    res.raise_for_status()
    results = res.json().get("organic", [])[:3]
    # Serper results may omit "snippet" (and occasionally other keys);
    # default to "" rather than letting a KeyError abort the whole answer.
    return "\n".join(
        f"πΉ [{r.get('title', '')}]({r.get('link', '')})\n{r.get('snippet', '')}"
        for r in results
    )
# --- Main Chat Logic ---
def respond_to_query(question, source, history):
    """Route a question to web search or document retrieval, then ask the LLM.

    Appends a [question, answer] pair to *history* and returns
    (updated history, "" to clear the input box).
    """
    if not question.strip():
        return history, ""
    history.append([question, None])
    try:
        if source == "π Web Search":
            context = fetch_web_snippets(question)
            source_note = "π Web Search"
        elif source == "π Uploaded File":
            if not doc_chunks:
                history[-1][1] = "β οΈ Please upload a PDF document first."
                return history, ""
            context = get_top_chunks(question)
            source_note = "π Uploaded Document"
        else:
            history[-1][1] = "β Invalid knowledge source selected."
            return history, ""
        history[-1][1] = f"**{source_note}**\n\n{call_together_ai(context, question)}"
    except Exception as e:
        history[-1][1] = f"β Error: {e}"
    return history, ""
# --- Clear Chat ---
def clear_chat():
    """Return a fresh empty history to reset the chatbot display."""
    return []
# --- UI Design ---
# Page-level CSS: cap layout width and center the headings.
css = """
.gradio-container { max-width: 1100px !important; margin: auto; }
h1, h2, h3 { text-align: center; }
"""

with gr.Blocks(css=css, theme=gr.themes.Soft(), title="π AI RAG Assistant") as demo:
    gr.HTML("<h1>π€ AI Chat with RAG Capabilities</h1><h3>Ask questions from PDFs or real-time web search</h3>")
    with gr.Row():
        with gr.Column(scale=1):
            # Left panel: knowledge-source selector, PDF upload, and status readouts.
            source = gr.Radio(["π Web Search", "π Uploaded File"], label="Knowledge Source", value="π Web Search")
            file = gr.File(label="Upload PDF", file_types=[".pdf"])
            status = gr.Textbox(label="Status", interactive=False)
            doc_info = gr.Textbox(label="Chunks Info", visible=False, interactive=False)
        with gr.Column(scale=2):
            # Right panel: chat history, question input, and action buttons.
            chatbot = gr.Chatbot(label="Chat", height=500)
            query = gr.Textbox(placeholder="Type your question here...", lines=2)
            with gr.Row():
                send = gr.Button("Send")
                clear = gr.Button("Clear")
    with gr.Accordion("βΉοΈ Info", open=False):
        gr.Markdown("- Web Search fetches latest online results\n- PDF mode retrieves answers from your document")
    gr.HTML("<div style='text-align:center; font-size:0.9em; color:gray;'>ipradeepsengarr</div>")

    # Bind events: uploading a file (re)builds the chunk cache; Enter or the
    # Send button submits a question; Clear empties the chat history.
    file.change(handle_file_upload, inputs=file, outputs=[status, doc_info])
    query.submit(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    send.click(respond_to_query, inputs=[query, source, chatbot], outputs=[chatbot, query])
    clear.click(clear_chat, outputs=[chatbot])

if __name__ == "__main__":
    # 0.0.0.0 binding suits container hosting; share=True also exposes a
    # temporary public Gradio link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)