# Hugging Face Space: fast astronomy RAG assistant.
# NOTE: the hosted Space currently reports a runtime error (see review note on the
# Gradio `allow_flagging` kwarg below for the likely cause under Gradio 5).
| import os | |
| import shutil | |
| from datasets import load_dataset | |
| from dotenv import load_dotenv | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_chroma import Chroma | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| import gradio as gr | |
# Load environment variables from a local .env file (supplies the Groq API key).
load_dotenv()

# NOTE(review): the original env var name is the non-standard 'groq_api_keys';
# also accept the conventional 'GROQ_API_KEY' as a fallback so standard setups work.
groq_key = os.environ.get('groq_api_keys') or os.environ.get('GROQ_API_KEY')

# 1. Faster LLM Initialization
# temperature=0 keeps answers deterministic for factual Q&A over abstracts.
llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_key, temperature=0)

# 2. Faster Embedding Model
# 'all-MiniLM-L6-v2' is 10x smaller and much faster than mxbai-large
# but still excellent for astronomy abstracts.
print("Loading embedding model...")  # original message had a mojibake emoji; removed
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
| DB_DIR = "./chroma_db" | |
| # 3. SMART PERSISTENCE: Only index if the database doesn't exist | |
| if not os.path.exists(DB_DIR) or len(os.listdir(DB_DIR)) == 0: | |
| print("π¦ First run: Indexing data (this will take a moment)...") | |
| # Load dataset only when needed | |
| ds = load_dataset("mehnaazasad/arxiv_astro_co_ga", split="test", streaming=True) | |
| # Take 50 samples efficiently | |
| data = [item["abstract"] for item in ds.take(50)] | |
| vectorstore = Chroma.from_texts( | |
| texts=data, | |
| embedding=embed_model, | |
| persist_directory=DB_DIR, | |
| collection_name="dataset_store" | |
| ) | |
| print("β Indexing complete.") | |
| else: | |
| print("π Database found! Loading existing index...") | |
| vectorstore = Chroma( | |
| collection_name="dataset_store", | |
| embedding_function=embed_model, | |
| persist_directory=DB_DIR, | |
| ) | |
| retriever = vectorstore.as_retriever(search_kwargs={"k": 3}) | |
| # 4. Optimized RAG Chain | |
| template = """You are astronomy expert. Use the context to answer. | |
| Context: {context} | |
| Question: {question} | |
| Answer:""" | |
| rag_prompt = PromptTemplate.from_template(template) | |
| rag_chain = ( | |
| {"context": retriever, "question": RunnablePassthrough()} | |
| | rag_prompt | |
| | llm | |
| | StrOutputParser() | |
| ) | |
| # 5. Optimized Streaming Function | |
def rag_memory_stream(text):
    """Stream the RAG chain's answer for *text* chunk by chunk (lower latency)."""
    # Delegate straight to the chain's streaming iterator.
    yield from rag_chain.stream(text)
| # Gradio Interface | |
| demo = gr.Interface( | |
| title="β‘ Fast Astronomy AI", | |
| fn=rag_memory_stream, | |
| inputs=gr.Textbox(placeholder="Ask about galaxies..."), | |
| outputs="text", | |
| examples=['what are the characteristics of blue compact dwarf?', 'What is cold dark matter?'], | |
| allow_flagging="never" | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() |