# Vision Transformer Assistant — Hugging Face Space
# (FastAPI backend + Streamlit frontend in a single script)
# Standard library
import os
import threading

# Third-party
import requests
import streamlit as st
import uvicorn
from fastapi import FastAPI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
| # β FastAPI Backend | |
| app = FastAPI(title="Vision Transformer Assistant", description="A FastAPI-powered AI assistant for deep learning.") | |
| # β Load Hugging Face Token π | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| if not HF_TOKEN: | |
| raise ValueError("β οΈ HF_TOKEN is missing! Add it in Hugging Face Secrets.") | |
| # β Load Documents π | |
| loader = DirectoryLoader("./data/", glob="*.pdf", loader_cls=PyPDFLoader) | |
| docs = loader.load() | |
| # β Text Splitting π | |
| text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) | |
| texts = text_splitter.split_documents(docs) | |
| # β Vector Database π | |
| db = FAISS.from_documents(documents=texts, embedding=HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")) | |
| retriever = db.as_retriever() | |
| # β Load LLM π | |
| repo_id = "mistralai/Mistral-7B-Instruct-v0.3" | |
| llm = HuggingFaceEndpoint(repo_id=repo_id, token=HF_TOKEN, task="text-generation") | |
| # β Prompt Template β¨ | |
| prompt_temp = ChatPromptTemplate.from_template(""" | |
| You are an AI assistant specializing in deep learning, specifically Vision Transformers. | |
| <context> | |
| {context} | |
| <context> | |
| ### Instructions: | |
| - Extract relevant information only from retrieved documents. | |
| - Provide concise yet detailed responses. | |
| - Use LaTeX for equations when necessary. | |
| - Do not make up answers; respond with *'Information not available in retrieved documents.'* if needed. | |
| """) | |
| # β Create Retrieval Chain β‘ | |
| document_chain = create_stuff_documents_chain(llm, prompt_temp) | |
| retrieval_chain = create_retrieval_chain(retriever, document_chain) | |
| def get_response(query: str) -> str: | |
| """ | |
| Get AI-generated response based on query. | |
| """ | |
| response = retrieval_chain.invoke({"input": query}) | |
| answer = response.get("answer", "Error: No answer generated.") | |
| # Debugging Logs | |
| print(f"Query: {query} | Answer: {answer}") | |
| return answer | |
| def home(): | |
| return {"message": "Vision Transformer Assistant API is running π"} | |
| def get_answer(query: str): | |
| """ | |
| API endpoint to retrieve AI-generated responses. | |
| """ | |
| try: | |
| answer = get_response(query) | |
| return {"answer": answer} | |
| except Exception as e: | |
| print(f"Error: {e}") | |
| return {"answer": "Error occurred while processing the request."} | |
| # β Run FastAPI in a separate thread | |
| def run_fastapi(): | |
| uvicorn.run(app, host="0.0.0.0", port=7860) | |
| threading.Thread(target=run_fastapi, daemon=True).start() | |
| # β Streamlit UI | |
| st.set_page_config(page_title="Vision Transformer Assistant", page_icon="π€") | |
| st.title("Vision Transformer Assistant π€") | |
| st.markdown("Ask anything about deep learning and Vision Transformers!") | |
| FASTAPI_URL = "http://127.0.0.1:7860" # β Make sure this matches your FastAPI server | |
| # User input | |
| query = st.text_input("Enter your question:") | |
| if st.button("Get Answer"): | |
| if query: | |
| with st.spinner("Fetching answer..."): | |
| try: | |
| response = requests.get(f"{FASTAPI_URL}/query", params={"query": query}) | |
| # Check if response is valid | |
| if response.status_code == 200: | |
| answer = response.json().get("answer", "No answer found.") | |
| st.success("β Answer:") | |
| st.write(answer) | |
| else: | |
| st.error(f"β οΈ Error fetching answer. Status Code: {response.status_code}") | |
| st.write(response.text) # Debugging info | |
| except requests.exceptions.RequestException as e: | |
| st.error(f"β οΈ Failed to connect to backend: {e}") | |