rameshmoorthy commited on
Commit
c5255f7
·
verified ·
1 Parent(s): 32bb56f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -14
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import gradio as gr
 
2
  from sentence_transformers import CrossEncoder
3
  import numpy as np
4
  from time import perf_counter
5
- from backend.semantic_search import table, retriever
 
 
6
  import os
7
  import logging
8
 
@@ -10,19 +13,59 @@ import logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # API Key setup (optional, depending on retriever needs)
14
  api_key = os.getenv("GROQ_API_KEY")
15
  if not api_key:
16
- gr.Warning("GROQ_API_KEY not found. Set it in 'Repository secrets' if required.")
17
  logger.error("GROQ_API_KEY not found.")
18
  else:
19
  os.environ["GROQ_API_KEY"] = api_key
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  VECTOR_COLUMN_NAME = "vector"
22
  TEXT_COLUMN_NAME = "text"
 
 
 
 
23
 
24
- def retrieve_documents(topic):
25
- gr.Warning('Retrieving documents may take a moment. Please wait.', duration=30)
 
 
 
 
 
 
 
 
 
26
  top_k_rank = 10
27
  documents = []
28
 
@@ -37,21 +80,34 @@ def retrieve_documents(topic):
37
  sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
38
  documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
39
 
40
- elapsed_time = perf_counter() - document_start
41
- return f"Retrieved {len(documents)} documents in {elapsed_time:.2f} seconds:\n\n" + "\n".join(documents)
 
 
 
 
 
 
42
 
43
- with gr.Blocks(title="Document Retrieval Test") as APP:
44
  with gr.Row():
45
- gr.Markdown("# Document Retrieval Test")
 
 
 
46
 
47
- topic = gr.Textbox(label="Enter Topic", placeholder="Enter a topic to search")
48
- output = gr.Textbox(label="Retrieved Documents", interactive=False)
49
 
50
- submit_btn = gr.Button("Retrieve Documents")
51
- submit_btn.click(fn=retrieve_documents, inputs=[topic], outputs=[output])
 
 
 
 
52
 
53
  if __name__ == "__main__":
54
- APP.queue().launch(server_name="0.0.0.0", server_port=7860)
55
  # import gradio as gr
56
  # from pathlib import Path
57
  # from tempfile import NamedTemporaryFile
 
1
  import gradio as gr
2
+ from pathlib import Path
3
  from sentence_transformers import CrossEncoder
4
  import numpy as np
5
  from time import perf_counter
6
+ from pydantic import BaseModel, Field
7
+ from phi.agent import Agent
8
+ from phi.model.groq import Groq
9
  import os
10
  import logging
11
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
+ # API Key setup
17
  api_key = os.getenv("GROQ_API_KEY")
18
  if not api_key:
19
+ gr.Warning("GROQ_API_KEY not found. Set it in 'Repository secrets'.")
20
  logger.error("GROQ_API_KEY not found.")
21
  else:
22
  os.environ["GROQ_API_KEY"] = api_key
23
 
24
+ # Pydantic Model for Quiz Structure
25
+ class QuizItem(BaseModel):
26
+ question: str = Field(..., description="The quiz question")
27
+ choices: list[str] = Field(..., description="List of 4 multiple-choice options")
28
+ correct_answer: str = Field(..., description="The correct choice (e.g., 'C1')")
29
+
30
+ class QuizOutput(BaseModel):
31
+ items: list[QuizItem] = Field(..., description="List of 10 quiz items")
32
+
33
+ # Initialize Agents
34
+ groq_agent = Agent(model=Groq(model="llama3-70b-8192", api_key=api_key), markdown=True)
35
+
36
+ quiz_generator = Agent(
37
+ name="Quiz Generator",
38
+ role="Generates structured quiz questions and answers",
39
+ instructions=[
40
+ "Create 10 questions with 4 choices each based on the provided topic and documents.",
41
+ "Use the specified difficulty level (easy, average, hard) to adjust question complexity.",
42
+ "Ensure questions are derived only from the provided documents.",
43
+ "Return the output in a structured format using the QuizOutput Pydantic model.",
44
+ "Each question should have a unique correct answer from the choices (labeled C1, C2, C3, C4)."
45
+ ],
46
+ model=Groq(id="llama3-70b-8192", api_key=api_key),
47
+ response_model=QuizOutput,
48
+ markdown=True
49
+ )
50
+
51
  VECTOR_COLUMN_NAME = "vector"
52
  TEXT_COLUMN_NAME = "text"
53
+ proj_dir = Path.cwd()
54
+
55
+ # Calling functions from backend (assuming they exist)
56
+ from backend.semantic_search import table, retriever
57
 
58
+ def generate_quiz_data(question_difficulty, topic, documents_str):
59
+ prompt = f"""Generate a quiz with {question_difficulty} difficulty on topic '{topic}' using only the following documents:\n{documents_str}"""
60
+ try:
61
+ response = quiz_generator.run(prompt)
62
+ return response.content
63
+ except Exception as e:
64
+ logger.error(f"Failed to generate quiz: {e}")
65
+ return None
66
+
67
+ def retrieve_and_generate_quiz(question_difficulty, topic):
68
+ gr.Warning('Generating quiz may take 1-2 minutes. Please wait.', duration=60)
69
  top_k_rank = 10
70
  documents = []
71
 
 
80
  sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
81
  documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
82
 
83
+ documents_str = '\n'.join(documents)
84
+ quiz_data = generate_quiz_data(question_difficulty, topic, documents_str)
85
+ if not quiz_data or not quiz_data.items:
86
+ return "Error: Failed to generate quiz."
87
+
88
+ # Display only the questions as plain text
89
+ questions = [f"{i}. {item.question}" for i, item in enumerate(quiz_data.items[:10], 1)]
90
+ return "\n".join(questions)
91
 
92
+ with gr.Blocks(title="Quiz Generator Test") as QUIZBOT:
93
  with gr.Row():
94
+ gr.Markdown("# Quiz Generator Test")
95
+
96
+ topic = gr.Textbox(label="Enter Topic", placeholder="Write any topic from 9th Science CBSE")
97
+ difficulty_radio = gr.Radio(["easy", "average", "hard"], label="How difficult should the quiz be?")
98
 
99
+ generate_quiz_btn = gr.Button("Generate Quiz")
100
+ quiz_output = gr.Textbox(label="Quiz Questions", interactive=False)
101
 
102
+ # Register the click event without @ decorator
103
+ generate_quiz_btn.click(
104
+ fn=retrieve_and_generate_quiz,
105
+ inputs=[difficulty_radio, topic],
106
+ outputs=[quiz_output]
107
+ )
108
 
109
  if __name__ == "__main__":
110
+ QUIZBOT.queue().launch(server_name="0.0.0.0", server_port=7860)
111
  # import gradio as gr
112
  # from pathlib import Path
113
  # from tempfile import NamedTemporaryFile