File size: 2,632 Bytes
c55ec21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import shutil
from datasets import load_dataset
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr

# Load environment variables from a local .env file so the Groq API key
# is not hard-coded in the source.
load_dotenv()
# NOTE(review): the env var name is 'groq_api_keys' (plural) — confirm it
# matches the .env file; a mismatch yields None and an auth failure later.
groq_key = os.environ.get('groq_api_keys')

# 1. Faster LLM Initialization
# temperature=0 → deterministic completions for the same retrieved context.
llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_key, temperature=0)

# 2. Faster Embedding Model
# 'all-MiniLM-L6-v2' is 10x smaller and much faster than mxbai-large
# but still excellent for astronomy abstracts.
# This downloads/loads the model at import time, hence the progress print.
print("⌛ Loading embedding model...")
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
DB_DIR = "./chroma_db"

# 3. SMART PERSISTENCE: Only index if the database doesn't exist
if not os.path.exists(DB_DIR) or len(os.listdir(DB_DIR)) == 0:
    print("📦 First run: Indexing data (this will take a moment)...")
    
    # Load dataset only when needed
    ds = load_dataset("mehnaazasad/arxiv_astro_co_ga", split="test", streaming=True)
    # Take 50 samples efficiently
    data = [item["abstract"] for item in ds.take(50)]
    
    vectorstore = Chroma.from_texts(
        texts=data,
        embedding=embed_model,
        persist_directory=DB_DIR,
        collection_name="dataset_store"
    )
    print("✅ Indexing complete.")
else:
    print("🚀 Database found! Loading existing index...")
    vectorstore = Chroma(
        collection_name="dataset_store",
        embedding_function=embed_model,
        persist_directory=DB_DIR,
    )

retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

# 4. Optimized RAG Chain
template = """You are astronomy expert. Use the context to answer. 
Context: {context}
Question: {question}
Answer:"""

rag_prompt = PromptTemplate.from_template(template)

rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)

# 5. Optimized Streaming Function
def rag_memory_stream(text):
    """Stream the RAG chain's answer for *text*, one chunk at a time.

    Args:
        text: The user's question, passed straight into the chain.

    Yields:
        str: Partial answer chunks as the LLM produces them, so the
        Gradio UI can render the response incrementally.
    """
    # `yield from` delegates directly to the stream iterator — identical
    # behavior to the explicit for/yield loop, but the idiomatic form
    # (PEP 380) and one less frame hop per chunk.
    yield from rag_chain.stream(text)

# Gradio Interface
# fn is a generator, so Gradio streams partial outputs into the textbox.
demo = gr.Interface(
    title="⚡ Fast Astronomy AI",
    fn=rag_memory_stream,
    inputs=gr.Textbox(placeholder="Ask about galaxies..."),
    outputs="text",
    examples=['what are the characteristics of blue compact dwarf?', 'What is cold dark matter?'],
    allow_flagging="never"  # NOTE(review): deprecated in Gradio 4.x (flagging_mode) — confirm installed version
)

if __name__ == "__main__":
    # Launch the web UI only when executed as a script, not on import.
    demo.launch()