import os
import threading
import uvicorn
import streamlit as st
import requests
from fastapi import FastAPI
# Use the current package layout: loaders/vectorstores moved to langchain_community,
# and the Hugging Face integrations live in langchain_huggingface.
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
# ✅ FastAPI Backend
app = FastAPI(title="Vision Transformer Assistant", description="A FastAPI-powered AI assistant for deep learning.")
# ✅ Load Hugging Face Token 🔑
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise ValueError("⚠️ HF_TOKEN is missing! Add it in Hugging Face Secrets.")
# ✅ Load Documents 📂
loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader)
docs = loader.load()
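# PyPDFLoader emits one Document per PDF page; note that glob="*.pdf" only matches
# files directly inside ./data/ (use "**/*.pdf" to pick up subdirectories as well).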
# ✅ Text Splitting 📖
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)
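# 1000-character chunks with a 200-character overlap keep each chunk small enough to
# stuff several into the prompt while preserving continuity across chunk boundaries.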
# ✅ Vector Database 🔍
db = FAISS.from_documents(documents=texts, embedding=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"))
retriever = db.as_retriever()
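# The FAISS index lives in memory and is rebuilt on every startup; as_retriever()
# defaults to similarity search returning the top 4 chunks per query.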
# ✅ Load LLM 🚀
repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(repo_id=repo_id, huggingfacehub_api_token=HF_TOKEN, task="text-generation")
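# HuggingFaceEndpoint sends generation requests to the hosted Hugging Face Inference
# API for this repo_id, so no local GPU is needed. Sampling parameters such as
# max_new_tokens or temperature can be passed here if the defaults prove too terse.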
# ✅ Prompt Template ✨
prompt_temp = ChatPromptTemplate.from_template("""
You are an AI assistant specializing in deep learning, specifically Vision Transformers.
<context>
{context}
</context>
### Question:
{input}
### Instructions:
- Extract relevant information only from the retrieved documents.
- Provide concise yet detailed responses.
- Use LaTeX for equations when necessary.
- Do not make up answers; respond with *'Information not available in retrieved documents.'* if needed.
""")
# ✅ Create Retrieval Chain ⚡
document_chain = create_stuff_documents_chain(llm, prompt_temp)
retrieval_chain = create_retrieval_chain(retriever, document_chain)
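# retrieval_chain.invoke({"input": ...}) runs the retriever on the query, stuffs the
# retrieved chunks into the prompt's {context} slot, and returns a dict containing
# "input", "context" (the source documents), and "answer".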
def get_response(query: str) -> str:
"""
Get AI-generated response based on query.
"""
response = retrieval_chain.invoke({"input": query})
answer = response.get("answer", "Error: No answer generated.")
# Debugging Logs
print(f"Query: {query} | Answer: {answer}")
return answer
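# Illustrative use (the question is just an example):
#   get_response("How does patch embedding work in a ViT?")
#   -> a grounded answer, or the fallback string if the PDFs don't cover it.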
@app.get("/")
def home():
    return {"message": "Vision Transformer Assistant API is running 🚀"}
@app.get("/query")
def get_answer(query: str):
"""
API endpoint to retrieve AI-generated responses.
"""
try:
answer = get_response(query)
return {"answer": answer}
except Exception as e:
print(f"Error: {e}")
return {"answer": "Error occurred while processing the request."}
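# Illustrative request (example query): GET /query?query=What%20is%20self-attention%3F
# responds with {"answer": "..."}; errors are reported in-band with a 200 status.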
# ✅ Run FastAPI in a separate thread; st.cache_resource ensures the server starts only
# once, since Streamlit re-executes this entire script on every user interaction.
@st.cache_resource(show_spinner=False)
def run_fastapi():
    thread = threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True)
    thread.start()
    return thread
run_fastapi()
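# Note: the document loading, splitting, and FAISS build above also re-run on every
# Streamlit rerun; a production version would wrap that setup in st.cache_resource too.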
# ✅ Streamlit UI
st.set_page_config(page_title="Vision Transformer Assistant", page_icon="🤖")
st.title("Vision Transformer Assistant 🤖")
st.markdown("Ask anything about deep learning and Vision Transformers!")
FASTAPI_URL = "http://127.0.0.1:7860"  # ✅ Must match the host/port uvicorn binds above
# User input
query = st.text_input("Enter your question:")
if st.button("Get Answer"):
if query:
with st.spinner("Fetching answer..."):
try:
                response = requests.get(f"{FASTAPI_URL}/query", params={"query": query}, timeout=120)
# Check if response is valid
if response.status_code == 200:
answer = response.json().get("answer", "No answer found.")
                    st.success("✅ Answer:")
st.write(answer)
else:
st.error(f"⚠️ Error fetching answer. Status Code: {response.status_code}")
st.write(response.text) # Debugging info
except requests.exceptions.RequestException as e:
st.error(f"⚠️ Failed to connect to backend: {e}")