import os
import shutil

from datasets import load_dataset
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr

load_dotenv()
# NOTE(review): env var is named 'groq_api_keys' (plural) — confirm this matches .env
groq_key = os.environ.get('groq_api_keys')

# 1. Faster LLM initialization (temperature=0 for deterministic answers)
llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_key, temperature=0)

# 2. Faster embedding model.
# 'all-MiniLM-L6-v2' is ~10x smaller and much faster than mxbai-large,
# but still excellent for astronomy abstracts.
print("⌛ Loading embedding model...")
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# On-disk location of the persisted Chroma index.
DB_DIR = "./chroma_db"

# 3. Smart persistence: only (re)index when no usable database is on disk.
# isdir (not exists) so a stray plain file at DB_DIR doesn't crash listdir.
if not os.path.isdir(DB_DIR) or not os.listdir(DB_DIR):
    print("📦 First run: Indexing data (this will take a moment)...")

    # Load the dataset lazily (streaming) only when indexing is needed.
    ds = load_dataset("mehnaazasad/arxiv_astro_co_ga", split="test", streaming=True)

    # Take 50 abstracts without materializing the whole split.
    data = [item["abstract"] for item in ds.take(50)]

    vectorstore = Chroma.from_texts(
        texts=data,
        embedding=embed_model,
        persist_directory=DB_DIR,
        collection_name="dataset_store",
    )
    print("✅ Indexing complete.")
else:
    print("🚀 Database found! Loading existing index...")
    vectorstore = Chroma(
        collection_name="dataset_store",
        embedding_function=embed_model,
        persist_directory=DB_DIR,
    )

# Retrieve the 3 most similar abstracts per query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# 4. Optimized RAG chain
template = """You are astronomy expert. Use the context to answer.
Context: {context}
Question: {question}
Answer:"""
rag_prompt = PromptTemplate.from_template(template)


def _format_docs(docs):
    """Join retrieved Document page_content into one prompt-ready string.

    Without this, the prompt would receive the repr of a list of Document
    objects instead of the abstracts' actual text.
    """
    return "\n\n".join(doc.page_content for doc in docs)


rag_chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
# 5. Optimized streaming function.
def rag_memory_stream(text):
    """Stream the RAG answer for *text* as progressively longer prefixes.

    Gradio replaces the output component with each yielded value, so we must
    yield the accumulated answer — yielding raw chunks would make the UI show
    only the latest token.
    """
    answer = ""
    for chunk in rag_chain.stream(text):
        answer += chunk
        yield answer


# Gradio interface: a single textbox in, streamed text out.
demo = gr.Interface(
    title="⚡ Fast Astronomy AI",
    fn=rag_memory_stream,
    inputs=gr.Textbox(placeholder="Ask about galaxies..."),
    outputs="text",
    examples=['what are the characteristics of blue compact dwarf?',
              'What is cold dark matter?'],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()