mayzinoo committed on
Commit
72a3c35
·
verified ·
1 Parent(s): 668e43c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -142
app.py CHANGED
@@ -1,145 +1,92 @@
 
 
1
  import os
 
 
 
 
 
2
  import re
3
- import zipfile
4
- import gradio as gr
5
- from langchain_openai import ChatOpenAI
6
- from langchain.embeddings import HuggingFaceEmbeddings
7
- from langchain_chroma import Chroma
8
- from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
9
- from langchain.chains import LLMChain
10
-
11
- # Unzip vector DB if not already extracted
12
- if not os.path.exists("geometry_chroma"):
13
- with zipfile.ZipFile("geometry_chroma.zip", 'r') as zip_ref:
14
- zip_ref.extractall(".")
15
-
16
- # Load vector DB
17
- embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
18
- vectordb = Chroma(persist_directory="geometry_chroma", embedding_function=embedding_model)
19
- retriever = vectordb.as_retriever()
20
-
21
- # Set OpenAI key (use Secrets or .env later)
22
- os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
23
-
24
- llm = ChatOpenAI(model_name="gpt-4.1", temperature=0.2)
25
-
26
- # ✅ Prompt templates
27
- templates = {
28
- "flashcard": PromptTemplate(
29
- input_variables=["context", "query"],
30
- template="""
31
- {context}
32
-
33
- Create 5 flashcards based on the topic: "{query}"
34
- Each flashcard should include:
35
- - A clear question
36
- - A short answer
37
- Focus on high school geometry understanding.
38
- """
39
- ),
40
- "lesson plan": PromptTemplate(
41
- input_variables=["context", "query"],
42
- template="""
43
- Given the following retrieved SOL text:
44
- {context}
45
-
46
- Generate a Geometry lesson plan based on: "{query}"
47
- Include:
48
- 1. Simple explanation of the concept.
49
- 2. Real-world example.
50
- 3. Engaging class activity.
51
- Be concise and curriculum-aligned for high school.
52
- """
53
- ),
54
- "worksheet": PromptTemplate(
55
- input_variables=["context", "query"],
56
- template="""
57
- {context}
58
-
59
- Create a student worksheet for: "{query}"
60
- Include:
61
- - Concept summary
62
- - A worked example
63
- - 3 practice problems
64
- """
65
- ),
66
- "proofs": PromptTemplate(
67
- input_variables=["context", "query"],
68
- template="""
69
- {context}
70
-
71
- Generate a proof-focused geometry lesson plan for: "{query}"
72
- Include:
73
- - Student-friendly explanation
74
- - Real-world connection
75
- - One short class activity
76
- """
77
- ),
78
- "general question": ChatPromptTemplate.from_messages([
79
- HumanMessagePromptTemplate.from_template(
80
- """
81
- You are a Virginia Geometry SOL assistant.
82
-
83
- From the following SOL context:
84
- {context}
85
-
86
- Identify the SOL standard (e.g., G.RLT.1) that best matches this query: "{query}"
87
-
88
- Respond with:
89
- 1. The exact SOL code (e.g., G.RLT.1)
90
- 2. The exact description line from the SOL guide
91
-
92
- Do not summarize. Only copy from the context.
93
- """
94
- )
95
- ])
96
-
97
-
98
-
99
- }
100
-
101
- def generate_prompt_output(prompt_type, query, retriever, llm):
102
- try:
103
- import re
104
- sol_match = re.search(r"\bG\.[A-Z]+\.\d+\b", query)
105
- matched_code = sol_match.group(0) if sol_match else None
106
-
107
- if matched_code:
108
- all_docs = retriever.vectorstore._collection.get(include=['documents', 'metadatas'])
109
- filtered = []
110
- for doc_text, metadata in zip(all_docs['documents'], all_docs['metadatas']):
111
- if metadata.get('standard') == matched_code:
112
- filtered.append(doc_text)
113
-
114
- context = "\n\n".join(filtered)
115
- else:
116
- docs = retriever.get_relevant_documents(query)
117
- context = "\n\n".join([doc.page_content for doc in docs])
118
-
119
- chain = LLMChain(llm=llm, prompt=templates[prompt_type])
120
- return chain.run({"context": context, "query": query}).strip()
121
-
122
- except Exception as e:
123
- return f"❌ Error: {str(e)}"
124
-
125
-
126
-
127
-
128
- # ✅ Gradio UI
129
- with gr.Blocks() as demo:
130
- gr.Markdown("# 📐 Geometry Teaching Assistant")
131
-
132
- with gr.Row():
133
- query = gr.Textbox(label="Enter a geometry topic")
134
- prompt_type = gr.Dropdown(
135
- ["general question", "lesson plan", "worksheet", "proofs", "flashcard"],
136
- value="general question",
137
- label="Prompt Type"
138
- )
139
-
140
- output = gr.Textbox(label="Generated Output", lines=12, interactive=True)
141
- btn = gr.Button("Generate")
142
-
143
- btn.click(fn=generate_prompt_output, inputs=[prompt_type, query], outputs=output)
144
 
145
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
  import os
4
+ from transformers import pipeline
5
+ from sentence_transformers import SentenceTransformer
6
+ import faiss
7
+ import numpy as np
8
+ import json
9
  import re
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+ # --- Load necessary components for the RAG system ---
12
+ # These paths are relative to the Space's root directory
13
+ FAISS_INDEX_PATH = "sol_faiss_index.bin"
14
+ DOCUMENT_IDS_PATH = "sol_document_ids.json"
15
+
16
+ # Load SentenceTransformer model
17
+ # Ensure this model is already cached or can be downloaded in the environment.
18
+ # For Spaces, add the sentence-transformers package to requirements.txt; the model weights are downloaded directly if the Space has internet access.
19
+ # The model is loaded once at module level so that it acts as a shared, global resource.
20
+ try:
21
+ model = SentenceTransformer('all-mpnet-base-v2')
22
+ except Exception as e:
23
+ print(f"Error loading SentenceTransformer model: {e}")
24
+ print("Attempting to load from local cache or download on first use.")
25
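+ # If running in a Space, the model is downloaded to the cache if not present, so ensure the Space has internet access.
26
+ # NOTE(review): `model` is never assigned on this failure path, so the later `if not model` check raises NameError; assign `model = None` here.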
27
+
28
+ # Load FAISS index
29
+ try:
30
+ index = faiss.read_index(FAISS_INDEX_PATH)
31
+ except Exception as e:
32
+ print(f"Error loading FAISS index: {e}")
33
+ # Handle error, maybe create a dummy index or exit
34
+ index = None # Placeholder if loading fails
35
+
36
+ # Load document IDs
37
+ try:
38
+ with open(DOCUMENT_IDS_PATH, "r") as f:
39
+ document_ids = json.load(f)
40
+ except Exception as e:
41
+ print(f"Error loading document IDs: {e}")
42
+ document_ids = [] # Placeholder if loading fails
43
+
44
+ # The SOL content comes from "10 Geometry Mathematics Instructional Guide.pdf".
45
+ # The PDF itself is not read here; in a real deployment its content would be loaded from a file
46
+ # that you upload to your Hugging Face Space, or fetched at runtime.
47
+ # This app assumes the guide has already been pre-processed into a 'documents' list,
48
+ # i.e. the 'documents' list created in Step 2.
49
+ # For deployment, that `documents` list (sol_data) is saved as a JSON file
50
+ # and loaded back below.
51
+
52
+ # Assuming you've saved sol_data as 'sol_documents.json'
53
+ SOL_DOCUMENTS_PATH = "sol_documents.json"
54
+ try:
55
+ with open(SOL_DOCUMENTS_PATH, "r") as f:
56
+ documents = json.load(f)
57
+ except Exception as e:
58
+ print(f"Error loading sol documents: {e}")
59
+ documents = [] # Placeholder
60
+
61
+ # Load LLM for generation
62
+ # For a Hugging Face Space, you need to ensure the model is available.
63
+ # 'google/gemma-2b-it' is a good option.
64
+ # Ensure you set up environment variables or secrets for API keys if using paid models.
65
+ try:
66
+ llm_pipeline = pipeline("text-generation", model="google/gemma-2b-it")
67
+ except Exception as e:
68
+ print(f"Error loading LLM pipeline: {e}")
69
+ llm_pipeline = None # Placeholder
70
+
71
+
72
+ def retrieve_and_generate_app(query, top_k=3):
73
+ if not model or not index or not document_ids or not documents or not llm_pipeline:
74
+ return "System not fully initialized. Please check logs for missing components."
75
+
76
+ # 1. Query Embedding
77
+ query_embedding = model.encode([query])
78
+
79
+ # 2. Retrieval using FAISS
80
+ D, I = index.search(query_embedding, top_k)
81
+
82
+ retrieved_docs = []
83
+ for i in I[0]:
84
+ sol_id = document_ids[i]
85
+ # Find the full content of the retrieved SOL
86
+ # This relies on the 'documents' list being correctly loaded and matching by ID
87
+ retrieved_content = next((doc["content"] for doc in documents if doc["id"] == sol_id), "Content not found.")
88
+ retrieved_docs.append({"id": sol_id, "content": retrieved_content})
89
+
90
+ context = "\n\n".join([f"SOL {doc['id']}: {doc['content']}" for doc in retrieved_docs])
91
+
92
+ prompt = f"""