Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import chromadb | |
| import gradio as gr | |
| from pypdf import PdfReader | |
| import docx | |
| from sentence_transformers import SentenceTransformer | |
| from groq import Groq | |
# =========================
# Groq API client (key supplied as a Hugging Face Space secret)
# =========================
# The API key is read from the environment variable "Multi_doc" (configure it
# under the Space's "Variables and secrets"). When that variable is unset,
# api_key is None and the Groq SDK falls back to its own default lookup
# (documented by the SDK as GROQ_API_KEY; confirm for the installed version).
groq_client = Groq(
    api_key=os.getenv("Multi_doc"),
)
# =========================
# Document loaders
# =========================
def load_pdf(path):
    """Return the extracted text of every page of the PDF at ``path``.

    Page texts are joined with newlines. A page for which ``extract_text()``
    returns a falsy value (e.g. an image-only page) contributes an empty string.
    """
    pages = PdfReader(path).pages
    page_texts = (page.extract_text() or "" for page in pages)
    return "\n".join(page_texts)
def load_docx(path):
    """Return the text of the paragraphs of the .docx file at ``path``, one per line.

    Only ``Document.paragraphs`` is read; text inside tables is presumably not
    included (python-docx exposes tables separately) — confirm if that matters.
    """
    paragraphs = docx.Document(path).paragraphs
    return "\n".join(par.text for par in paragraphs)
def load_txt(path):
    """Return the full contents of the text file at ``path``.

    The file is decoded as UTF-8. A leading UTF-8 byte-order mark is stripped
    (``utf-8-sig``) so it does not end up glued to the first word of the first
    chunk. Byte sequences that are not valid UTF-8 are replaced with U+FFFD
    instead of raising, so a Latin-1/CP1252 upload is still indexed rather
    than rejected outright.
    """
    with open(path, "r", encoding="utf-8-sig", errors="replace") as f:
        return f.read()
def load_document(path):
    """Extract plain text from the file at ``path``, dispatching on its extension.

    Supported extensions (case-insensitive): ``.pdf``, ``.docx``, ``.txt``.

    Raises:
        ValueError: if the file has no extension or an unsupported one.
    """
    # splitext only inspects the final path component, so dots in directory
    # names (e.g. "/tmp/gradio/a.b/README") are not mistaken for an extension.
    ext = os.path.splitext(path)[1][1:].lower()
    if ext == "pdf":
        return load_pdf(path)
    if ext == "docx":
        return load_docx(path)
    if ext == "txt":
        return load_txt(path)
    if not ext:
        raise ValueError(
            f"Unsupported file type: {os.path.basename(path)} has no extension"
        )
    raise ValueError(f"Unsupported file type: .{ext}")
# =========================
# Chunking
# =========================
def chunk_text(text, size=400, overlap=80):
    """Split ``text`` into overlapping word windows.

    Args:
        text: input string; it is split on whitespace.
        size: maximum number of words per chunk (must be > 0).
        overlap: number of words shared by consecutive chunks
            (must satisfy ``0 <= overlap < size``).

    Returns:
        A list of ``{"id": int, "text": str}`` dicts with sequential ids.
        An empty or whitespace-only ``text`` yields ``[]``.

    Raises:
        ValueError: if ``size``/``overlap`` would give a non-positive step,
            which would otherwise loop forever.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")

    words = text.split()
    step = size - overlap
    chunks = []
    start = 0
    while start < len(words):
        chunks.append({
            "id": len(chunks),
            "text": " ".join(words[start:start + size]),
        })
        # Stop once this window reached the end. Advancing further would only
        # emit a tail chunk that is entirely contained in the current one.
        if start + size >= len(words):
            break
        start += step
    return chunks
# =========================
# Embeddings (computed locally with sentence-transformers)
# =========================
# Loaded once at import time; shared by indexing and retrieval.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")


def embed(texts):
    """Embed a list of strings with ``embed_model``.

    The progress bar is disabled. The result is converted with ``.tolist()``
    into nested Python lists, the form passed to Chroma's ``add``/``query``
    elsewhere in this file.
    """
    vectors = embed_model.encode(texts, show_progress_bar=False)
    return vectors.tolist()
# =========================
# Chroma vector store
# The HF Spaces root filesystem is read-only, so the database is kept under /tmp.
# =========================
chroma_client = chromadb.PersistentClient(path="/tmp/chroma_db")
# Single collection named "rag"; reused if it already exists at that path.
collection = chroma_client.get_or_create_collection(name="rag")
# =========================
# Indexing uploaded files
# =========================
def process_files(files):
    """Load, chunk, embed and store the uploaded files in the Chroma collection.

    Args:
        files: value of the ``gr.File`` component, a list whose items are either
            path strings or objects exposing the path as ``.name`` (Gradio on HF
            may pass a NamedString).

    Returns:
        A status string for the UI. Per-file problems (empty text, read/parse
        errors) are collected and reported alongside the result rather than
        aborting the whole batch.
    """
    if not files:
        return "โ ๏ธ No files uploaded."

    all_chunks = []
    errors = []
    indexed_files = 0  # files that actually contributed chunks
    for f in files:
        # Gradio on HF passes the file path as a string or NamedString.
        file_path = f if isinstance(f, str) else f.name
        if not file_path:
            continue
        name = os.path.basename(file_path)
        try:
            text = load_document(file_path)
            if not text.strip():
                errors.append(f"โ ๏ธ {name} appears empty.")
                continue
            chunks = chunk_text(text)
        except Exception as e:
            # Best-effort: one unreadable file must not block the others.
            errors.append(f"โ Error reading {name}: {e}")
            continue
        all_chunks.extend({"source": name, "text": c["text"]} for c in chunks)
        indexed_files += 1

    if not all_chunks:
        return "\n".join(errors) if errors else "โ ๏ธ No content could be extracted."

    texts = [c["text"] for c in all_chunks]
    embeddings = embed(texts)
    # NOTE: ids are random uuid4 values, so processing the same file twice
    # stores its chunks twice.
    collection.add(
        ids=[str(uuid.uuid4()) for _ in all_chunks],
        embeddings=embeddings,
        documents=texts,
        metadatas=[{"source": c["source"]} for c in all_chunks],
    )

    # Report only the files that were indexed, not every uploaded file.
    result = f"โ Indexed {indexed_files} file(s) โ {len(all_chunks)} chunks stored."
    if errors:
        result += "\n" + "\n".join(errors)
    return result
# =========================
# Retrieval
# =========================
def retrieve(query, k=3):
    """Return up to ``k`` stored chunks most similar to ``query``.

    Each hit is a ``{"text": ..., "source": ...}`` dict. When the collection is
    empty, ``[]`` is returned without embedding the query. ``k`` is capped at
    the number of stored chunks, since no more than that can be retrieved.
    """
    stored = collection.count()
    if stored == 0:
        return []
    n_results = min(k, stored)

    query_vector = embed([query])[0]
    hits = collection.query(
        query_embeddings=[query_vector],
        n_results=n_results
    )

    texts = hits["documents"][0]
    metadatas = hits["metadatas"][0]
    return [
        {"text": text, "source": metadatas[i]["source"]}
        for i, text in enumerate(texts)
    ]
# =========================
# Answer generation with Groq
# =========================
def generate(query):
    """Answer ``query`` with the Groq chat model, grounded on retrieved chunks.

    Returns a Markdown string: the model's answer followed by a "Sources"
    section that quotes the first 200 characters of each retrieved chunk.
    If nothing is indexed, or the Groq call raises, a status/error string is
    returned instead.
    """
    docs = retrieve(query)
    if not docs:
        return "โ ๏ธ No documents indexed yet. Please upload and process files first."

    context = "\n\n".join(f"[{d['source']}]\n{d['text']}" for d in docs)
    prompt = "\n".join([
        "You are a strict RAG assistant.",
        "Answer ONLY from the context below.",
        'If the answer is not found in the context, say: "Not found in documents."',
        "CONTEXT:",
        context,
        "QUESTION:",
        query,
        "ANSWER:",
    ])

    try:
        completion = groq_client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=1024,
        )
        answer = completion.choices[0].message.content
    except Exception as e:
        # Surface API failures in the chat instead of crashing the handler.
        return f"โ Groq API error: {e}"

    source_blocks = [f"๐ **{d['source']}**\n{d['text'][:200]}โฆ" for d in docs]
    return f"{answer}\n\n---\n๐ **Sources:**\n" + "\n\n".join(source_blocks)
# =========================
# Chat handler
# Gradio 5 uses {"role": ..., "content": ...} dicts, not tuples
# =========================
def chat(message, history):
    """Handle one chat turn for the ``gr.Chatbot(type="messages")`` component.

    Args:
        message: the user's question from the textbox (``None`` is tolerated).
        history: list of ``{"role", "content"}`` message dicts, or ``None``.

    Returns:
        ``("", new_history)``. The empty string clears the input textbox.
        ``new_history`` is a new list; the caller's list is not mutated. A blank
        message returns the history unchanged, without calling the model.
    """
    # Copy rather than append in place, so the caller's list is never mutated.
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    reply = generate(message)
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": reply})
    return "", history
# =========================
# Gradio UI
# =========================
# Two-column layout: file upload + indexing on the left (scale=1), chat on the
# right (scale=2). Components appear in the order they are created inside the
# context managers, so statement order here defines the layout.
with gr.Blocks(title="Groq RAG Assistant") as app:
    gr.Markdown(
        """# ๐ง Groq RAG Assistant
Upload your documents, then ask questions about them.
Powered by **Groq LLaMA3** + **ChromaDB** + **sentence-transformers**.
"""
    )
    with gr.Row():
        # Left column: upload files and index them into the Chroma collection.
        with gr.Column(scale=1):
            gr.Markdown("### ๐ Upload Documents")
            files = gr.File(
                file_count="multiple",
                file_types=[".pdf", ".docx", ".txt"],
                label="Upload PDF / DOCX / TXT"
            )
            process_btn = gr.Button("๐ Process Files", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            # process_files returns a single status string, shown in `status`.
            process_btn.click(fn=process_files, inputs=files, outputs=status)
        # Right column: question answering over the indexed documents.
        with gr.Column(scale=2):
            gr.Markdown("### ๐ฌ Ask Your Documents")
            # Gradio 5: type="messages" uses the new dict format
            chatbot = gr.Chatbot(height=480, type="messages")
            msg = gr.Textbox(
                placeholder="Ask a question about your documentsโฆ",
                label="Your question",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Chat")
            # The Send button and submitting the textbox both run `chat`, which
            # returns ("", history): the textbox is cleared and the chatbot updated.
            submit_btn.click(fn=chat, inputs=[msg, chatbot], outputs=[msg, chatbot])
            msg.submit(fn=chat, inputs=[msg, chatbot], outputs=[msg, chatbot])
            # Resets the conversation and the input box only; indexed documents
            # in the Chroma collection are not touched.
            clear_btn.click(fn=lambda: ([], ""), outputs=[chatbot, msg])
# =========================
# Entry point
# =========================
def main():
    """Start the Gradio app with its default launch settings."""
    app.launch()


if __name__ == "__main__":
    main()