Hammad712 commited on
Commit
2748d2d
·
verified ·
1 Parent(s): 03b42ab

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +147 -92
main.py CHANGED
@@ -1,104 +1,159 @@
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from langchain.document_loaders import WikipediaLoader
4
- from langchain_groq import ChatGroq
5
- import os
6
-
7
- app = FastAPI(title="Quiz Generator API")
8
 
9
- # in-memory store for the last quiz + context
10
- STORE = {
11
- "quiz": None, # str
12
- "context": None, # str
13
- }
14
-
15
- # Replace with your actual Groq API key
16
- GROQ_API_KEY = os.getenv('api_key')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class QuizRequest(BaseModel):
19
- search_query: str
20
 
21
  class GradeRequest(BaseModel):
22
- answers: str
23
-
24
- def get_llm():
25
- return ChatGroq(
26
- model="meta-llama/llama-4-scout-17b-16e-instruct",
27
- temperature=0,
28
- max_tokens=1024,
29
- api_key=GROQ_API_KEY
30
- )
31
 
32
- def wikipedia_query(search_query: str):
 
 
33
  try:
34
- docs = WikipediaLoader(query=search_query, load_max_docs=2).load()
35
- return docs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  except Exception as e:
37
- raise HTTPException(status_code=500, detail=f"Wikipedia query failed: {e}")
 
38
 
39
  @app.get("/")
40
- async def root():
41
- return {
42
- "message": "Welcome to the Quiz Generator API! \n"
43
- "POST /generate_quiz/ to create a new quiz. \n"
44
- "POST /grade_quiz/ to grade your answers against the last quiz."
45
- }
46
-
47
- @app.post("/generate_quiz/")
48
- async def generate_quiz(request: QuizRequest):
49
- # fetch & store context
50
-
51
- llm = get_llm()
52
- search_query=llm.invoke(f'write one word which should be search for this question {request.search_query}')
53
-
54
- context_docs = wikipedia_query(search_query.content)
55
- context_text = str(context_docs)
56
- STORE["context"] = context_text
57
-
58
- # generate quiz
59
-
60
- prompt = f"""
61
- You are a quiz generator assistant. Create a quiz for the given context.
62
-
63
- Instructions:
64
- - Do not write answers in the quiz.
65
- - Quiz should be based on the following context:
66
-
67
- context:
68
- {context_text}
69
-
70
- question:
71
- generate quiz on {request.search_query}
72
-
73
- Your response:
74
- """
75
- result = llm.invoke(prompt)
76
- STORE["quiz"] = result.content
77
-
78
- return {
79
- "quiz": result.content
80
- }
81
-
82
- @app.post("/grade_quiz/")
83
- async def grade_quiz(request: GradeRequest):
84
- # ensure we have a quiz to grade
85
- if STORE["quiz"] is None or STORE["context"] is None:
86
- raise HTTPException(status_code=400, detail="No quiz available. Call /generate_quiz/ first.")
87
-
88
- llm = get_llm()
89
- prompt = f"""
90
- Check the quiz answers and give marks and also provide the correct answers. Use the following context to check the quiz. Return only the total mark and the correct answers.
91
- Do not write anything else expect the marks and correct answers do not generate new quiz.
92
- quiz:
93
- {STORE['quiz']}
94
-
95
- answers:
96
- {request.answers}
97
-
98
- context:
99
- {STORE['context']}
100
- """
101
- result = llm.invoke(prompt)
102
- return {
103
- "grade": result.content
104
- }
 
1
+ import os
2
+ import zipfile
3
+ import logging
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
 
 
 
 
 
6
 
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+ from langchain_groq import ChatGroq
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.prompts import PromptTemplate
12
+
13
+ # Configure logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ app = FastAPI()
21
+
22
+ # === Globals ===
23
+ llm = None
24
+ embeddings = None
25
+ vectorstore = None
26
+ retriever = None
27
+ quiz_chain = None
28
+ grade_chain = None
29
 
30
  class QuizRequest(BaseModel):
31
+ topic: str
32
 
33
  class GradeRequest(BaseModel):
34
+ answers: str # string of Q/A pairs
 
 
 
 
 
 
 
 
35
 
36
+ @app.on_event("startup")
37
+ def load_components():
38
+ global llm, embeddings, vectorstore, retriever, quiz_chain, grade_chain
39
  try:
40
+ api_key = os.getenv("API_KEY")
41
+ if not api_key:
42
+ logger.error("API_KEY environment variable is not set or empty.")
43
+ raise RuntimeError("API_KEY environment variable is not set or empty.")
44
+ logger.info("API_KEY is set.")
45
+
46
+ # 1) Init LLM & Embeddings
47
+ llm = ChatGroq(
48
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
49
+ temperature=0,
50
+ max_tokens=1024,
51
+ api_key=api_key,
52
+ )
53
+ embeddings = HuggingFaceEmbeddings(
54
+ model_name="intfloat/multilingual-e5-large",
55
+ model_kwargs={"device": "cpu"},
56
+ encode_kwargs={"normalize_embeddings": True},
57
+ )
58
+
59
+ # 2) Load FAISS indexes
60
+ for zip_name, dir_name in [("faiss_index.zip", "faiss_index"), ("faiss_index(1).zip", "faiss_index_extra")]:
61
+ if not os.path.exists(dir_name):
62
+ with zipfile.ZipFile(zip_name, 'r') as z:
63
+ z.extractall(dir_name)
64
+ logger.info(f"Unzipped {zip_name} to {dir_name}.")
65
+ else:
66
+ logger.info(f"Directory {dir_name} already exists.")
67
+
68
+ vs1 = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
69
+ logger.info("FAISS index 1 loaded.")
70
+ vs2 = FAISS.load_local("faiss_index_extra", embeddings, allow_dangerous_deserialization=True)
71
+ logger.info("FAISS index 2 loaded.")
72
+
73
+ vs1.merge_from(vs2)
74
+ vectorstore = vs1
75
+ logger.info("Merged FAISS indexes into a single vectorstore.")
76
+
77
+ retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
78
+
79
+ # Quiz generation chain
80
+ quiz_prompt = PromptTemplate(
81
+ template="""
82
+ Generate a quiz on the topic "{topic}" using **only** the information in the "Retrieved context".
83
+ Include clear questions and multiple-choice options (A, B, C, D).
84
+ If context is insufficient, reply with "I don't know".
85
+
86
+ Retrieved context:
87
+ {context}
88
+
89
+ Quiz topic:
90
+ {topic}
91
+
92
+ Quiz:
93
+ """,
94
+ input_variables=["context", "topic"],
95
+ )
96
+ quiz_chain = RetrievalQA.from_chain_type(
97
+ llm=llm,
98
+ chain_type="stuff",
99
+ retriever=retriever,
100
+ return_source_documents=False,
101
+ chain_type_kwargs={"prompt": quiz_prompt},
102
+ )
103
+ logger.info("Quiz chain ready.")
104
+
105
+ # Grade quiz chain
106
+ grade_prompt = PromptTemplate(
107
+ template="""
108
+ Grade the following quiz answers based on the "Retrieved context".
109
+ Provide a score and brief feedback for each question.
110
+ If context is insufficient to grade, say "I don't know" for that question.
111
+
112
+ Retrieved context:
113
+ {context}
114
+
115
+ User answers:
116
+ {answers}
117
+
118
+ Grading:
119
+ """,
120
+ input_variables=["context", "answers"],
121
+ )
122
+ grade_chain = RetrievalQA.from_chain_type(
123
+ llm=llm,
124
+ chain_type="stuff",
125
+ retriever=retriever,
126
+ return_source_documents=False,
127
+ chain_type_kwargs={"prompt": grade_prompt},
128
+ )
129
+ logger.info("Grading chain ready.")
130
+
131
  except Exception as e:
132
+ logger.error("Error loading components", exc_info=True)
133
+ raise
134
 
135
  @app.get("/")
136
+ def root():
137
+ return {"message": "API is up and running!"}
138
+
139
+ @app.post("/quiz")
140
+ def create_quiz(request: QuizRequest):
141
+ try:
142
+ logger.info("Generating quiz for topic: %s", request.topic)
143
+ result = quiz_chain.invoke({"query": request.topic})
144
+ logger.info("Quiz generated successfully.")
145
+ return {"quiz": result.get("result")}
146
+ except Exception as e:
147
+ logger.error("Error generating quiz", exc_info=True)
148
+ raise HTTPException(status_code=500, detail=str(e))
149
+
150
+ @app.post("/grade")
151
+ def grade_quiz(request: GradeRequest):
152
+ try:
153
+ logger.info("Grading quiz with provided answers.")
154
+ result = grade_chain.invoke({"query": request.answers})
155
+ logger.info("Quiz graded successfully.")
156
+ return {"grading": result.get("result")}
157
+ except Exception as e:
158
+ logger.error("Error grading quiz", exc_info=True)
159
+ raise HTTPException(status_code=500, detail=str(e))