# Source: sultan-hassan's Hugging Face Space — "Create app.py" (commit c55ec21, verified)
import os
import shutil
from datasets import load_dataset
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
load_dotenv()

# Read the Groq API key from the environment (a local .env file is honored
# via load_dotenv above).
groq_key = os.environ.get('groq_api_keys')
# Fail fast with an actionable message instead of a cryptic 401 on the
# first model call.
if not groq_key:
    raise RuntimeError(
        "Missing Groq API key: set the 'groq_api_keys' environment variable "
        "(or add it to a .env file)."
    )

# 1. Faster LLM Initialization
# temperature=0 keeps answers deterministic for factual Q&A.
llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_key, temperature=0)

# 2. Faster Embedding Model
# 'all-MiniLM-L6-v2' is 10x smaller and much faster than mxbai-large
# but still excellent for astronomy abstracts.
print("βŒ› Loading embedding model...")
embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
DB_DIR = "./chroma_db"

# 3. SMART PERSISTENCE: reuse the on-disk Chroma index when one exists,
# otherwise build it once from the dataset.
_index_exists = os.path.exists(DB_DIR) and len(os.listdir(DB_DIR)) > 0

if _index_exists:
    print("πŸš€ Database found! Loading existing index...")
    vectorstore = Chroma(
        collection_name="dataset_store",
        embedding_function=embed_model,
        persist_directory=DB_DIR,
    )
else:
    print("πŸ“¦ First run: Indexing data (this will take a moment)...")
    # Stream the dataset so only the 50 abstracts we need are fetched.
    ds = load_dataset("mehnaazasad/arxiv_astro_co_ga", split="test", streaming=True)
    abstracts = [record["abstract"] for record in ds.take(50)]
    vectorstore = Chroma.from_texts(
        texts=abstracts,
        embedding=embed_model,
        persist_directory=DB_DIR,
        collection_name="dataset_store",
    )
    print("βœ… Indexing complete.")

# Retrieve the 3 most similar abstracts for each incoming question.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
# 4. Optimized RAG Chain
# Prompt contract: the model answers using only the retrieved context.
template = """You are astronomy expert. Use the context to answer.
Context: {context}
Question: {question}
Answer:"""

rag_prompt = PromptTemplate.from_template(template)

# Pipeline, step by step:
#   1. fan the question out: retriever fills {context}, the question
#      itself passes through unchanged as {question};
#   2. render the prompt;
#   3. call the LLM;
#   4. reduce the model response to a plain string.
_chain_inputs = {"context": retriever, "question": RunnablePassthrough()}
rag_chain = _chain_inputs | rag_prompt | llm | StrOutputParser()
# 5. Optimized Streaming Function
def rag_memory_stream(text):
    """Stream the RAG chain's answer for *text* as it is generated.

    Gradio displays the LATEST yielded value (it replaces, not appends),
    so we must yield the accumulated answer so far — yielding raw chunks
    would show only the most recent token fragment in the UI.
    """
    answer = ""
    for chunk in rag_chain.stream(text):
        answer += chunk
        yield answer
# Gradio Interface: one textbox in, streamed text out.
demo = gr.Interface(
    fn=rag_memory_stream,
    title="⚑ Fast Astronomy AI",
    inputs=gr.Textbox(placeholder="Ask about galaxies..."),
    outputs="text",
    examples=[
        'what are the characteristics of blue compact dwarf?',
        'What is cold dark matter?',
    ],
    # This demo collects no feedback, so hide the flag button.
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()