File size: 2,458 Bytes
69ba0b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa652bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel

# --- Imports for LCEL ---
# We replace the create_..._chain imports with these building blocks
from operator import itemgetter  # A handy tool to get a value from a dict
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
# --- End of new imports ---

from langchain_community.vectorstores import FAISS
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_core.prompts import PromptTemplate

# --- 1. SETUP ---
# Order matters here: load_dotenv() must run before os.getenv() so the key
# from the .env file is visible.
load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")  # NOTE(review): may be None if unset — no explicit check

app = FastAPI()

# Embedding model; must match the model used when "faiss_index" was built,
# otherwise similarity search results will be meaningless — TODO confirm.
embeddings = GoogleGenerativeAIEmbeddings(
    model="gemini-embedding-001", google_api_key=api_key
)
# allow_dangerous_deserialization=True is required because FAISS indexes are
# pickled on disk; only safe when "faiss_index" comes from a trusted source.
vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
retriever = vector_store.as_retriever()

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key=api_key)

# Prompt expects exactly two variables: {context} (retrieved documents) and
# {input} (the user's question) — the same keys the chain below produces.
template = """
You are a helpful AI assistant. Answer the user's question based on the
following context. If you don't know the answer, just say "I don't know."

Context: {context}
Question: {input}
"""
prompt = PromptTemplate.from_template(template)


# --- 2. BUILD YOUR CHAIN WITH LCEL ---

# "Stuff" step: fill the prompt with {context, input}, call the model, and
# parse the reply down to a plain string.
document_chain = prompt | llm | StrOutputParser()

# Full RAG pipeline. A plain dict piped into a runnable is coerced into a
# RunnableParallel, so this maps the incoming {"input": ...} to the
# {context, input} dict that document_chain's prompt expects.
retrieval_chain = {
    # Retrieve documents relevant to the user's question.
    "context": itemgetter("input") | retriever,
    # Forward the question itself unchanged.
    "input": itemgetter("input"),
} | document_chain


# --- 3. YOUR API (No changes needed) ---

@app.get("/")
def read_root():
    """Landing endpoint: points visitors at the interactive /docs UI."""
    greeting = {"Hello": "Welcome to the Gemini RAG API. Go to /docs to test."}
    return greeting

class Query(BaseModel):
    """Request body for POST /ask: `query` is the user's free-text question."""
    query: str

@app.post("/ask")
async def ask_query(query: Query):
    """Answer a question with the RAG chain.

    Args:
        query: Request body; `query.query` is the user's question.

    Returns:
        A JSON object {"answer": <string>} with the model's reply.

    Fix: the original called the synchronous `.invoke()` inside an
    `async def` endpoint, which blocks FastAPI's event loop for the whole
    LLM/retriever round-trip (stalling all concurrent requests). LCEL
    runnables expose an async `ainvoke`, so we await that instead.
    """
    # The chain's input key must be "input" — it is what the itemgetter
    # steps in retrieval_chain read.
    answer = await retrieval_chain.ainvoke({"input": query.query})

    return {"answer": answer}