siddheshrj committed on
Commit
7b0a8d1
·
verified ·
1 Parent(s): 8f58ae6

Update Prompt Template (Stranger Things RAG v7)

Browse files
Files changed (1) hide show
  1. main.py +46 -77
main.py CHANGED
@@ -1,56 +1,42 @@
1
- import os
2
- from fastapi import FastAPI, Request
3
  from fastapi.responses import HTMLResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from fastapi.templating import Jinja2Templates
6
- from pydantic import BaseModel
7
  from langchain_community.vectorstores import FAISS
8
  from langchain_huggingface import HuggingFaceEmbeddings
9
- from langchain_huggingface import HuggingFaceEndpoint
 
 
 
10
  from dotenv import load_dotenv
11
 
12
- # Load environment variables
13
  load_dotenv()
14
 
15
  app = FastAPI()
16
 
17
  # Mount static files
18
  app.mount("/static", StaticFiles(directory="static"), name="static")
19
-
20
- # Templates
21
  templates = Jinja2Templates(directory="templates")
22
 
23
- # Initialize RAG Components
24
- BASE_DIR = os.path.dirname(os.path.abspath(__file__))
25
- FAISS_PATH = os.path.join(BASE_DIR, "faiss_index")
26
- embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
27
 
28
- print(f"DEBUG: Checking for database at {FAISS_PATH}")
29
- # Check if DB exists
30
  if os.path.exists(FAISS_PATH):
31
- print("DEBUG: Database found. Loading FAISS...")
32
- try:
33
- vector_db = FAISS.load_local(FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
34
- retriever = vector_db.as_retriever(search_kwargs={"k": 5})
35
- print("DEBUG: Retriever initialized.")
36
- except Exception as e:
37
- print(f"DEBUG: Failed to load FAISS: {e}")
38
- retriever = None
39
  else:
40
- print("WARNING: FAISS index not found at path. Run ingest.py first.")
41
- import requests
42
- import json
43
- from langchain_core.runnables import RunnableLambda
44
 
45
- # Custom DeepSeek Connector (Cloned from Reference Repo)
46
  def call_deepseek_v3(prompt_input):
47
- # Handle LangChain prompt objects
48
  if hasattr(prompt_input, "to_string"):
49
  prompt_text = prompt_input.to_string()
50
  else:
51
  prompt_text = str(prompt_input)
52
 
53
- # Direct Router API used by the reference repo
54
  api_url = "https://router.huggingface.co/v1/chat/completions"
55
  token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
56
 
@@ -67,7 +53,7 @@ def call_deepseek_v3(prompt_input):
67
  "messages": [
68
  {
69
  "role": "system",
70
- "content": "You are an expert on Stranger Things. Answer clearly and concisely."
71
  },
72
  {
73
  "role": "user",
@@ -75,93 +61,76 @@ def call_deepseek_v3(prompt_input):
75
  }
76
  ],
77
  "temperature": 0.3,
78
- "max_tokens": 512,
79
  "stream": False
80
  }
81
 
82
  try:
83
  response = requests.post(api_url, headers=headers, json=payload)
84
  if response.status_code != 200:
85
- print(f"API Error {response.status_code}: {response.text}")
86
- return f"DeepSeek Error: {response.text}"
87
  return response.json()["choices"][0]["message"]["content"]
88
  except Exception as e:
89
- print(f"DeepSeek Connection Exception: {e}")
90
- return f"Error: {e}"
91
 
92
  llm = RunnableLambda(call_deepseek_v3)
93
- print("DeepSeek V3.2 Client (Custom Request) initialized!")
94
 
95
- from langchain_core.prompts import PromptTemplate
96
- from langchain_core.output_parsers import StrOutputParser
97
- from langchain_core.runnables import RunnablePassthrough
98
 
99
- # ... (Previous LLM setup remains)
 
 
 
 
100
 
101
- # LCEL RAG Chain
102
- template = """You are an expert on Stranger Things. Use the context below to generate a natural, engaging answer in your own words.
103
- Do not just copy the text. Synthesize the information.
104
- Format your answer as a detailed response (at least 3-4 sentences).
105
- Crucial: If the question is about a character, YOU MUST INCLUDE:
106
- 1. Their key relationships (girlfriend/boyfriend, best friends).
107
- 2. Their role or passion (e.g., Dungeon Master, journalist, sheriff).
108
- 3. Any iconic traits.
 
 
 
109
 
110
  Context:
111
  {context}
112
 
113
- Question:
114
  {question}
115
 
116
- Answer:"""
 
117
 
118
  prompt = PromptTemplate.from_template(template)
119
 
120
  def format_docs(docs):
121
  return "\n\n".join(doc.page_content for doc in docs)
122
 
123
- if retriever and llm:
124
  rag_chain = (
125
  {"context": retriever | format_docs, "question": RunnablePassthrough()}
126
  | prompt
127
  | llm
128
  )
129
- print("DEBUG: rag_chain constructed successfully.")
130
  else:
131
- print(f"DEBUG: rag_chain initialization skipped. Retriever: {retriever is not None}, LLM: {llm is not None}")
132
  rag_chain = None
133
 
134
- class QueryRequest(BaseModel):
135
- query: str
136
-
137
  @app.get("/", response_class=HTMLResponse)
138
  async def read_root(request: Request):
139
  return templates.TemplateResponse("index.html", {"request": request})
140
 
141
- @app.post("/query")
142
- async def query_rag(request: QueryRequest):
143
- print(f"DEBUG: Incoming query: {request.query}")
144
- print(f"DEBUG: rag_chain type: {type(rag_chain)}")
145
- print(f"DEBUG: rag_chain is: {rag_chain}")
146
-
147
  if not rag_chain:
148
- return {"answer": "System is initializing or data is missing. Please check server logs.", "sources": []}
149
 
150
- try:
151
- # Get answer
152
- answer = rag_chain.invoke(request.query)
153
-
154
- # Get sources separately since LCEL simple chain doesn't return them by default
155
- # unless we modify the runable to return a dict. For now, we'll re-retrieve for sources
156
- # or just skip sources to keep it simple as per user request for "|" operator specific demo.
157
- # But to keep sources, let's do a quick retrieve:
158
- source_docs = retriever.invoke(request.query)
159
- sources = [doc.metadata.get("source", "Unknown") for doc in source_docs]
160
- sources = list(set(sources))
161
-
162
- return {"answer": answer, "sources": sources}
163
- except Exception as e:
164
- return {"answer": f"Error: {str(e)}", "sources": []}
165
 
166
  if __name__ == "__main__":
167
  import uvicorn
 
1
+ from fastapi import FastAPI, Request, Form
 
2
  from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
  from fastapi.templating import Jinja2Templates
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda
9
+ import os
10
+ import requests
11
  from dotenv import load_dotenv
12
 
 
13
  load_dotenv()
14
 
15
  app = FastAPI()
16
 
17
  # Mount static files
18
  app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
19
  templates = Jinja2Templates(directory="templates")
20
 
21
# --- Vector store setup ------------------------------------------------
# Load the FAISS index produced by ingest.py. If it is missing, the
# retriever is left as None and the RAG chain below is disabled.
FAISS_PATH = "faiss_index"
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

if not os.path.exists(FAISS_PATH):
    print("WARNING: FAISS index not found. Run ingest.py first.")
    retriever = None
else:
    # allow_dangerous_deserialization is required because FAISS indexes
    # are pickled; only load indexes we generated ourselves.
    vector_db = FAISS.load_local(FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    retriever = vector_db.as_retriever(search_kwargs={"k": 3})
    print("DEBUG: FAISS index loaded.")
 
 
32
 
33
+ # Custom DeepSeek V3.2 Wrapper
34
  def call_deepseek_v3(prompt_input):
 
35
  if hasattr(prompt_input, "to_string"):
36
  prompt_text = prompt_input.to_string()
37
  else:
38
  prompt_text = str(prompt_input)
39
 
 
40
  api_url = "https://router.huggingface.co/v1/chat/completions"
41
  token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
42
 
 
53
  "messages": [
54
  {
55
  "role": "system",
56
+ "content": "You are a Stranger Things expert. Answer clearly."
57
  },
58
  {
59
  "role": "user",
 
61
  }
62
  ],
63
  "temperature": 0.3,
64
+ "max_tokens": 700,
65
  "stream": False
66
  }
67
 
68
  try:
69
  response = requests.post(api_url, headers=headers, json=payload)
70
  if response.status_code != 200:
71
+ return f"DeepSeek Error ({response.status_code}): {response.text}"
 
72
  return response.json()["choices"][0]["message"]["content"]
73
  except Exception as e:
74
+ return f"Connection Error: {e}"
 
75
 
# Wrap the raw HTTP call as a LangChain Runnable so it composes with the
# "|" (LCEL) operator in the chain below.
llm = RunnableLambda(call_deepseek_v3)
 
77
 
78
# Prompt template for the RAG chain (Stranger Things v7).
# {context} and {question} are the two input variables filled in by the
# LCEL chain below; the rules steer tone, length, and answer structure.
template = """
You are a Stranger Things expert assistant. Answer the user's question using ONLY the provided context.

Important rules:
- Do NOT copy sentences directly from the context. Rewrite in your own words.
- If the context does NOT contain the answer, say: "I don’t have enough information in the provided context to answer that fully."
- Keep the tone natural, friendly, and engaging.
- Write at least 4–6 sentences unless the question is very simple.

If the question is about a CHARACTER, you MUST include:
1) Full name + who they are in the story
2) Key relationships (friends, family, love interest, major connections)
3) Role / occupation / passion (student, sheriff, journalist, Dungeon Master, etc.)
4) Iconic traits (personality, behavior, famous moments or skills)

If the question is about an EVENT / LOCATION / OBJECT, you MUST include:
1) What it is
2) Why it matters in the story
3) Who is involved
4) Any major consequences or impact

Context:
{context}

User Question:
{question}

Answer (detailed and structured):
"""

# from_template infers the input variables from the {placeholders} above.
prompt = PromptTemplate.from_template(template)
110
 
111
def format_docs(docs):
    """Join the page contents of the retrieved documents into a single
    context string, with a blank line between each document."""
    pieces = [doc.page_content for doc in docs]
    return "\n\n".join(pieces)
113
 
114
# Build the LCEL pipeline only when the vector store loaded successfully;
# otherwise leave rag_chain as None so the endpoint can report the error.
if retriever is None:
    rag_chain = None
else:
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
    )
122
 
 
 
 
123
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Serve the landing page."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
126
 
127
@app.post("/get_response")
async def get_response(request: Request, query: str = Form(...)):
    """Run the RAG chain on the submitted form query and render the answer.

    Renders an error message in the template when the chain is not
    initialized or the invocation fails, instead of surfacing a 500.
    """
    if not rag_chain:
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "response": "System Error: RAG chain not initialized."},
        )

    try:
        # NOTE(review): invoke() is blocking; fine for a demo, but a busy
        # deployment should offload it (e.g. run_in_threadpool).
        result = rag_chain.invoke(query)
    except Exception as e:
        # Backend/API failures (network, HF router errors) must not crash
        # the request — show the error to the user like the chain does.
        result = f"Error: {e}"

    return templates.TemplateResponse("index.html", {"request": request, "response": result})
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  if __name__ == "__main__":
136
  import uvicorn