# Vision Transformer Assistant — FastAPI backend + Streamlit UI (Hugging Face Space)
import os
import threading
import uvicorn
import streamlit as st
import requests
from fastapi import FastAPI
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpoint
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
# --- FastAPI backend ---
app = FastAPI(
    title="Vision Transformer Assistant",
    description="A FastAPI-powered AI assistant for deep learning.",
)

# --- Hugging Face token (required by the inference endpoint) ---
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("⚠️ HF_TOKEN is missing! Add it in Hugging Face Secrets.")

# --- Load documents: every PDF under ./data/ ---
loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader)
docs = loader.load()

# --- Split into overlapping chunks so retrieval returns focused passages ---
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)

# --- Vector store + retriever (FAISS over BGE embeddings) ---
db = FAISS.from_documents(
    documents=texts,
    embedding=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
)
retriever = db.as_retriever()

# --- LLM served via the Hugging Face Inference endpoint ---
repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(repo_id=repo_id, token=HF_TOKEN, task="text-generation")

# --- Prompt template ---
# BUG FIX: the original template left the <context> tag unclosed and, more
# importantly, contained no {input} placeholder — create_retrieval_chain feeds
# the user's question in under the "input" key, so without {input} the LLM
# never saw the question at all.
prompt_temp = ChatPromptTemplate.from_template("""
You are an AI assistant specializing in deep learning, specifically Vision Transformers.
<context>
{context}
</context>
### Question:
{input}
### Instructions:
- Extract relevant information only from retrieved documents.
- Provide concise yet detailed responses.
- Use LaTeX for equations when necessary.
- Do not make up answers; respond with *'Information not available in retrieved documents.'* if needed.
""")

# --- Retrieval chain: stuff the retrieved docs into the prompt, then call the LLM ---
document_chain = create_stuff_documents_chain(llm, prompt_temp)
retrieval_chain = create_retrieval_chain(retriever, document_chain)
def get_response(query: str) -> str:
    """Run *query* through the retrieval chain and return the generated answer.

    Falls back to a fixed error string when the chain produces no "answer" key.
    """
    result = retrieval_chain.invoke({"input": query})
    answer = result.get("answer", "Error: No answer generated.")
    # Debugging log: echo the query/answer pair to stdout.
    print(f"Query: {query} | Answer: {answer}")
    return answer
@app.get("/")
def home():
    """Health-check endpoint for the API root."""
    # Restored the 🚀 emoji that a bad encoding round-trip had mangled to "π".
    return {"message": "Vision Transformer Assistant API is running 🚀"}
@app.get("/query")
def get_answer(query: str):
    """API endpoint to retrieve AI-generated responses.

    Always returns HTTP 200 with an {"answer": ...} payload; failures are
    reported inside the payload rather than as HTTP errors.
    """
    try:
        return {"answer": get_response(query)}
    except Exception as e:
        # Top-level API boundary: log and degrade to a generic error payload.
        print(f"Error: {e}")
        return {"answer": "Error occurred while processing the request."}
# --- Run FastAPI in a background thread so Streamlit can own the main thread ---
# BUG FIX: the section header had lost its leading "#" and sat on its own line,
# which is a SyntaxError; restored it as a proper comment.
def run_fastapi():
    """Serve the FastAPI app with uvicorn on port 7860 (blocking call)."""
    uvicorn.run(app, host="0.0.0.0", port=7860)

threading.Thread(target=run_fastapi, daemon=True).start()
# --- Streamlit UI ---
st.set_page_config(page_title="Vision Transformer Assistant", page_icon="🤖")
st.title("Vision Transformer Assistant 🤖")
st.markdown("Ask anything about deep learning and Vision Transformers!")

# Must match the host/port uvicorn binds in run_fastapi above.
FASTAPI_URL = "http://127.0.0.1:7860"

# User input
query = st.text_input("Enter your question:")
if st.button("Get Answer"):
    if query:
        with st.spinner("Fetching answer..."):
            try:
                # BUG FIX: added a timeout — requests.get without one can hang
                # the UI forever if the backend thread dies.
                response = requests.get(
                    f"{FASTAPI_URL}/query",
                    params={"query": query},
                    timeout=60,
                )
                # BUG FIX: the success banner's string literal had been split
                # across two lines (SyntaxError) and its ✅ emoji mangled.
                if response.status_code == 200:
                    answer = response.json().get("answer", "No answer found.")
                    st.success("✅ Answer:")
                    st.write(answer)
                else:
                    st.error(f"⚠️ Error fetching answer. Status Code: {response.status_code}")
                    st.write(response.text)  # surface the raw body for debugging
            except requests.exceptions.RequestException as e:
                st.error(f"⚠️ Failed to connect to backend: {e}")