Spaces:
Sleeping
Sleeping
| import os | |
| import zipfile | |
| import logging | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_groq import ChatGroq | |
| from langchain.chains import RetrievalQA | |
| from langchain.prompts import PromptTemplate | |
# Configure process-wide logging once at import time, then obtain the
# logger used throughout this module.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()

# === Globals ===
# Runtime components. Every one starts as None and is populated by
# load_components().
(
    llm,          # ChatGroq chat model
    embeddings,   # HuggingFace embedding model shared by the FAISS indexes
    vectorstore,  # FAISS store produced by merging the bundled indexes
    retriever,    # retriever built from `vectorstore`
    quiz_chain,   # RetrievalQA chain that generates quizzes
    grade_chain,  # NOTE(review): declared global in load_components but never assigned in the visible code
) = (None,) * 6
class QuizRequest(BaseModel):
    """Request body for quiz generation."""

    # Topic of the quiz; it is also used verbatim as the retrieval query.
    question: str
class GradeRequest(BaseModel):
    """Request body for grading answers (no handler uses it in the code shown here)."""

    # Despite the field name, this carries the question/answer pairs as a
    # single string.
    question: str
# (zip archive, extraction directory) for each bundled FAISS index. The first
# entry is the base store; every later entry is merged into it.
_FAISS_ARCHIVES = [
    ("faiss_index.zip", "faiss_index"),
    ("faiss_index(1).zip", "faiss_index_extra"),
]


def _ensure_unzipped(zip_name, dir_name):
    """Extract *zip_name* into *dir_name* unless that directory already exists."""
    import shutil

    if os.path.exists(dir_name):
        logger.info("Directory %s already exists.", dir_name)
        return
    try:
        # The archives ship with the app. extractall() is not hardened against
        # malicious member paths, so never call this on user-supplied archives.
        with zipfile.ZipFile(zip_name, "r") as z:
            z.extractall(dir_name)
    except Exception:
        # Remove a half-extracted directory. Otherwise the next start-up sees
        # it, skips extraction, and tries to load a broken index.
        shutil.rmtree(dir_name, ignore_errors=True)
        raise
    logger.info("Unzipped %s to %s.", zip_name, dir_name)


def load_components():
    """Initialise the LLM, embeddings, merged FAISS store and quiz chain.

    Populates the module globals ``llm``, ``embeddings``, ``vectorstore``,
    ``retriever`` and ``quiz_chain``.

    Raises:
        RuntimeError: no API key is found in the ``api_key`` / ``API_KEY``
            environment variables.
        Exception: any failure while building models, extracting or loading
            indexes is logged with its traceback and re-raised.
    """
    # NOTE(review): grade_chain is declared global but is never built here.
    # GradeRequest has no chain yet (TODO).
    global llm, embeddings, vectorstore, retriever, quiz_chain, grade_chain
    try:
        # The deployment reads lower-case ``api_key``. ``API_KEY``, the name the
        # messages have always mentioned, is accepted as well.
        api_key = os.getenv("api_key") or os.getenv("API_KEY")
        if not api_key:
            msg = "API key environment variable (api_key or API_KEY) is not set or empty."
            logger.error(msg)
            raise RuntimeError(msg)
        logger.info("API_KEY is set.")

        # 1) Init LLM & Embeddings
        llm = ChatGroq(
            model="meta-llama/llama-4-scout-17b-16e-instruct",
            temperature=0,
            max_tokens=1024,
            api_key=api_key,
        )
        embeddings = HuggingFaceEmbeddings(
            model_name="intfloat/multilingual-e5-large",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        # 2) Unpack and load FAISS indexes, then merge them into one store
        for zip_name, dir_name in _FAISS_ARCHIVES:
            _ensure_unzipped(zip_name, dir_name)

        stores = []
        for number, (_, dir_name) in enumerate(_FAISS_ARCHIVES, start=1):
            # allow_dangerous_deserialization: load_local unpickles the stored
            # docstore. This is acceptable only because these indexes are
            # bundled with the app. Never point it at user-provided files.
            stores.append(
                FAISS.load_local(dir_name, embeddings, allow_dangerous_deserialization=True)
            )
            logger.info("FAISS index %d loaded.", number)

        base, *extras = stores
        for extra in extras:
            base.merge_from(extra)
        vectorstore = base
        logger.info("Merged FAISS indexes into a single vectorstore.")
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})

        # Quiz generation chain
        quiz_prompt = PromptTemplate(
            template="""
Generate a quiz on the topic "{question}" using **only** the information in the "Retrieved context".
Include clear questions and multiple-choice options (A, B, C, D). Also provide the answers of the questions with them.
If context is insufficient, reply with "I don't know".
Retrieved context:
{context}
Quiz topic:
{question}
Quiz:
""",
            input_variables=["context", "question"],
        )
        # "stuff" places every retrieved document into the prompt's {context}.
        quiz_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": quiz_prompt},
        )
        logger.info("Quiz chain ready.")
    except Exception:
        logger.exception("Error loading components")
        raise
def root():
    """Liveness check: report that the API process is up."""
    # NOTE(review): no route decorator is visible here. Presumably it is
    # registered as the "/" GET route; confirm.
    status_text = "API is up and running!"
    return dict(message=status_text)
def create_quiz(request: QuizRequest):
    """Generate a multiple-choice quiz on ``request.question``.

    The topic is passed to ``quiz_chain`` as its retrieval query. The chain's
    ``"result"`` text is returned under the ``"quiz"`` key.

    Raises:
        HTTPException(422): the topic is empty or only whitespace.
        HTTPException(503): ``quiz_chain`` has not been built yet.
        HTTPException(500): the chain failed while generating the quiz.
    """
    # NOTE(review): no route decorator is visible for this handler. Presumably
    # it is registered on ``app`` as a POST route; confirm.
    if not request.question.strip():
        # A blank topic would spend an LLM call retrieving arbitrary context.
        raise HTTPException(status_code=422, detail="Quiz topic must not be empty.")
    if quiz_chain is None:
        # Without this guard the request fails as an opaque 500 raised from
        # calling .invoke on None.
        logger.error("Quiz requested before components were loaded.")
        raise HTTPException(
            status_code=503,
            detail="Quiz service is not ready; components are not loaded.",
        )
    try:
        logger.info("Generating quiz for topic: %s", request.question)
        result = quiz_chain.invoke({"query": request.question})
        logger.info("Quiz generated successfully.")
        return {"quiz": result.get("result")}
    except Exception as e:
        logger.exception("Error generating quiz")
        raise HTTPException(status_code=500, detail=str(e)) from e