import os
import threading
import uvicorn
import streamlit as st
import requests
from fastapi import FastAPI
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEndpoint
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain

# FastAPI backend application; served on port 7860 by run_fastapi() below.
app = FastAPI(title="Vision Transformer Assistant", description="A FastAPI-powered AI assistant for deep learning.")

# Hugging Face Inference API token, read from the environment
# (on Spaces, set it via the repository Secrets).
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    # Fail fast at import time rather than erroring on the first query.
    raise ValueError("⚠️ HF_TOKEN is missing! Add it in Hugging Face Secrets.")

# Load every PDF directly under ./data/ (glob is non-recursive) as documents.
loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader)
docs = loader.load()

# Split into ~1000-character chunks with 200-character overlap so passages
# are not cut in half at chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(docs)

# In-memory FAISS index over BGE embeddings; `retriever` is the module-level
# handle consumed by create_retrieval_chain() below.
db = FAISS.from_documents(documents=texts, embedding=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"))
retriever = db.as_retriever()

# Remote LLM on the Hugging Face Inference API (no local weights downloaded).
repo_id = "mistralai/Mistral-7B-Instruct-v0.3"
llm = HuggingFaceEndpoint(repo_id=repo_id, token=HF_TOKEN, task="text-generation")

# Prompt template for the stuff-documents chain.
# Fixes two defects in the original template:
#   1. The retrieved context was wrapped in two *opening* <context> tags;
#      the second is now a proper closing </context> tag.
#   2. The template never contained the user's question, so the LLM answered
#      from context alone — create_retrieval_chain passes the query under the
#      "input" key, which is now referenced via the {input} placeholder.
prompt_temp = ChatPromptTemplate.from_template("""
You are an AI assistant specializing in deep learning, specifically Vision Transformers.

<context>
{context}
</context>

### Instructions:
- Extract relevant information only from retrieved documents.
- Provide concise yet detailed responses.
- Use LaTeX for equations when necessary.
- Do not make up answers; respond with *'Information not available in retrieved documents.'* if needed.

### Question:
{input}
""")

# Retrieval pipeline: the stuff-documents chain formats retrieved chunks into
# the prompt's {context} slot; create_retrieval_chain puts the FAISS retriever
# in front, so retrieval_chain.invoke({"input": ...}) yields an "answer" key.
document_chain = create_stuff_documents_chain(llm, prompt_temp)
retrieval_chain = create_retrieval_chain(retriever, document_chain)

def get_response(query: str) -> str:
    """Run *query* through the retrieval chain and return the generated answer.

    Falls back to a fixed error string when the chain's output lacks an
    "answer" key.
    """
    chain_output = retrieval_chain.invoke({"input": query})
    answer = chain_output.get("answer", "Error: No answer generated.")
    # Debug trace for the server log.
    print(f"Query: {query} | Answer: {answer}")
    return answer

@app.get("/")
def home():
    """Root health-check endpoint confirming the API is up."""
    status = {"message": "Vision Transformer Assistant API is running πŸš€"}
    return status

@app.get("/query")
def get_answer(query: str):
    """API endpoint: return the AI-generated answer for *query* as JSON.

    Any failure in the chain is logged and reported as a generic error
    payload rather than a 500 response.
    """
    try:
        return {"answer": get_response(query)}
    except Exception as exc:
        print(f"Error: {exc}")
        return {"answer": "Error occurred while processing the request."}

# Run the FastAPI app on a daemon thread so the same process can continue
# into the Streamlit script below (a single-process deployment pattern).
# Port 7860 must match FASTAPI_URL used by the UI.
def run_fastapi():
    # Blocking call; the daemon thread dies with the main Streamlit process.
    uvicorn.run(app, host="0.0.0.0", port=7860)

threading.Thread(target=run_fastapi, daemon=True).start()

# --- Streamlit UI -----------------------------------------------------------
st.set_page_config(page_title="Vision Transformer Assistant", page_icon="πŸ€–")
st.title("Vision Transformer Assistant πŸ€–")
st.markdown("Ask anything about deep learning and Vision Transformers!")

# Base URL of the in-process FastAPI backend; must match run_fastapi()'s port.
FASTAPI_URL = "http://127.0.0.1:7860"

# User input
query = st.text_input("Enter your question:")

if st.button("Get Answer"):
    if query:
        with st.spinner("Fetching answer..."):
            try:
                # timeout= prevents the UI from hanging forever when the
                # backend thread is not up yet or the LLM call stalls; a slow
                # response surfaces as a RequestException handled below.
                response = requests.get(
                    f"{FASTAPI_URL}/query",
                    params={"query": query},
                    timeout=120,
                )

                # Check if response is valid
                if response.status_code == 200:
                    answer = response.json().get("answer", "No answer found.")
                    st.success("βœ… Answer:")
                    st.write(answer)
                else:
                    st.error(f"⚠️ Error fetching answer. Status Code: {response.status_code}")
                    st.write(response.text)  # Debugging info

            except requests.exceptions.RequestException as e:
                st.error(f"⚠️ Failed to connect to backend: {e}")
    else:
        # Previously a click with an empty box silently did nothing.
        st.warning("Please enter a question first.")