"""FastAPI service that answers Vision Transformer questions with RAG.

At import time the module loads the PDFs matching ``*.pdf`` in ``./data/``,
splits them into overlapping chunks, embeds the chunks into a FAISS index,
and wires a retrieval chain around a Hugging Face hosted LLM endpoint.
The resulting chain is exposed through the ``/query/`` HTTP endpoint.
"""

import os

from fastapi import FastAPI, HTTPException
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

app = FastAPI()

# Hugging Face API token, read from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Load the source documents: PDFs matching "*.pdf" in ./data/.
loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader)
docs = loader.load()

# Split the documents into 1000-character chunks with a 200-character overlap.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)

# Fail early with an actionable message instead of an opaque error from the
# vector store when there is nothing to index.
if not texts:
    raise RuntimeError(
        "No PDF content found in ./data/; cannot build the vector index."
    )

# Vector database: embed every chunk and index it with FAISS.
db = FAISS.from_documents(
    documents=texts,
    embedding=HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5'),
)
retriever = db.as_retriever()

# LLM served through the Hugging Face endpoint API.
repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(
    repo_id=repo_id,
    # The credential field is `huggingfacehub_api_token`; a bare `token=`
    # keyword is not a declared field of HuggingFaceEndpoint.
    huggingfacehub_api_token=HF_TOKEN,
    task="text-generation",
)

# Prompt template. `{context}` receives the retrieved chunks and `{input}`
# receives the user's question; without `{input}` the model would never see
# what was actually asked.
prompt_temp = ChatPromptTemplate.from_template("""
You are an AI assistant specializing in deep learning, specifically Vision Transformers.

{context}

### Instructions:
- Extract relevant information only from retrieved documents.
- Provide concise yet detailed responses.
- Use LaTeX for equations when necessary.
- Do not make up answers; respond with *'Information not available in retrieved documents.'* if needed.

### Question:
{input}
""")

# Stuff the retrieved chunks into the prompt, then put retrieval in front.
document_chain = create_stuff_documents_chain(llm, prompt_temp)
retrieval_chain = create_retrieval_chain(retriever, document_chain)


@app.get("/")
def home():
    """Return a static status message showing that the API is up."""
    return {"message": "Vision Transformer Assistant API is running 🚀"}


@app.get("/query/")
def get_answer(query: str):
    """Answer ``query`` using the retrieval chain.

    Args:
        query: The user's question, supplied as a query-string parameter.

    Returns:
        A dict with a single ``"answer"`` key holding the chain's answer.

    Raises:
        HTTPException: 400 when ``query`` is empty or only whitespace.
    """
    # Reject blank questions before spending a retrieval and an LLM call.
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query must not be empty.")

    response = retrieval_chain.invoke({'input': query})
    return {"answer": response['answer']}


# Run FastAPI on port 7860 when executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)