rairo committed on
Commit
dd33104
·
verified ·
1 Parent(s): e07e4ef

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +64 -79
main.py CHANGED
@@ -5,122 +5,107 @@ import numpy as np
5
  import pickle
6
  from flask import Flask, request, jsonify
7
  from flask_cors import CORS
8
- import google.generativeai as genai
9
  import firebase_admin
10
  from firebase_admin import credentials, firestore
11
  from dotenv import load_dotenv
12
 
 
 
 
13
  load_dotenv()
14
 
15
- # --------- Flask Setup ---------
16
  app = Flask(__name__)
17
  CORS(app)
18
 
19
- # --------- Firebase Initialization ---------
20
  cred_json = os.environ.get("FIREBASE")
21
  if cred_json:
22
  cred = credentials.Certificate(json.loads(cred_json))
23
  firebase_admin.initialize_app(cred)
24
  fs = firestore.client()
25
 
26
- # --------- Gemini Configuration ---------
27
- genai.configure(api_key=os.getenv("Gemini"))
28
- chat_model = genai.GenerativeModel("gemini-2.0-flash-thinking-exp")
29
- # --------- Paths for Cached Index ---------
30
- INDEX_PATH = "vector.index"
31
- DOCS_PATH = "documents.pkl"
32
-
33
- # --------- Load Documents from Firestore ---------
34
- def fetch_documents():
35
- documents = []
36
-
37
- for doc in fs.collection("participants").stream():
38
- d = doc.to_dict()
39
- documents.append(f"{d.get('name')} ({d.get('enterpriseName')}), sector: {d.get('sector')}, stage: {d.get('stage')}, type: {d.get('developmentType')}.")
40
-
41
- for doc in fs.collection("interventions").stream():
42
- d = doc.to_dict()
43
- for item in d.get("interventions", []):
44
- documents.append(f"Intervention: {item.get('title')} under {d.get('area')}.")
45
 
46
- for doc in fs.collection("feedbacks").stream():
47
- d = doc.to_dict()
48
- documents.append(f"Feedback on {d.get('interventionTitle')} by {d.get('smeName')}: {d.get('comment')}")
49
-
50
- for doc in fs.collection("complianceDocuments").stream():
51
- d = doc.to_dict()
52
- documents.append(f"Compliance document '{d.get('documentType')}' for {d.get('participantName')} is {d.get('status')} and expires on {d.get('expiryDate')}.")
53
-
54
- for doc in fs.collection("assignedInterventions").stream():
55
- d = doc.to_dict()
56
- documents.append(f"Assigned intervention '{d.get('interventionTitle')}' for {d.get('smeName')} by consultant {d.get('consultantId')} with status {d.get('status')}.")
57
-
58
- for doc in fs.collection("consultants").stream():
59
- d = doc.to_dict()
60
- documents.append(f"Consultant {d.get('name')} with expertise in {', '.join(d.get('expertise', []))} and rating {d.get('rating')}.")
61
-
62
- return documents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # --------- FAISS Caching ---------
65
  def build_or_load_index():
66
  if os.path.exists(INDEX_PATH) and os.path.exists(DOCS_PATH):
67
- print("Loading FAISS index and documents from cache...")
68
  with open(DOCS_PATH, "rb") as f:
69
  documents = pickle.load(f)
70
  index = faiss.read_index(INDEX_PATH)
71
  else:
72
- print("Building FAISS index...")
73
  documents = fetch_documents()
74
- embeddings = np.array(get_embeddings(documents), dtype="float32")
75
- dimension = len(embeddings[0])
76
- index = faiss.IndexFlatIP(dimension)
77
- index.add(embeddings)
78
-
79
- # Save cache
80
  with open(DOCS_PATH, "wb") as f:
81
  pickle.dump(documents, f)
82
  faiss.write_index(index, INDEX_PATH)
83
  return documents, index
84
 
85
- def get_embeddings(texts):
86
- response= genai.embed_content(
87
- model="models/text-embedding-004", content=texts, output_dimensionality=10
88
- )
89
- return [e.values for e in response.embeddings]
90
-
91
  documents, index = build_or_load_index()
92
 
93
- # --------- Helper Function: RAG Chat ---------
94
- def retrieve_and_respond(user_query, top_k=3):
95
- query_embedding = np.array(get_embeddings([user_query]), dtype="float32")
96
- distances, indices = index.search(query_embedding, top_k)
97
- retrieved_docs = [documents[i] for i in indices[0]]
98
-
99
- prompt = (
100
- "Use the following context to answer the question.\n\n"
101
- + "\n\n".join(retrieved_docs)
102
- + f"\n\nQuestion: {user_query}\nAnswer:"
103
- )
104
-
105
- chat = chat_model.start_chat()
106
- response = chat.send_message(prompt)
107
- return getattr(response, "text", response.last.text)
108
-
109
- # --------- Flask Chat Endpoint ---------
110
  @app.route("/chat", methods=["POST"])
111
- def chat():
112
  data = request.get_json(force=True)
113
- user_query = data.get("user_query")
114
-
115
- if not user_query:
116
  return jsonify({"error": "Missing user_query"}), 400
117
-
118
  try:
119
- reply = retrieve_and_respond(user_query)
120
- return jsonify({"reply": reply})
121
  except Exception as e:
122
  return jsonify({"error": str(e)}), 500
123
 
124
- # --------- Run Flask Server ---------
125
  if __name__ == "__main__":
126
  app.run(host="0.0.0.0", port=7860, debug=True)
 
5
  import pickle
6
  from flask import Flask, request, jsonify
7
  from flask_cors import CORS
 
8
  import firebase_admin
9
  from firebase_admin import credentials, firestore
10
  from dotenv import load_dotenv
11
 
12
+ from google import genai
13
+ from google.genai import types
14
+
15
# Load environment variables (FIREBASE credentials, Gemini API key) from .env.
load_dotenv()

# --------- Flask & Firebase Setup ---------
# Flask app with permissive CORS so a browser front-end served from another
# origin can POST to /chat directly.
app = Flask(__name__)
CORS(app)

# Firebase service-account credentials arrive as a JSON string in the
# FIREBASE env var (convenient for container/Space secrets).
cred_json = os.environ.get("FIREBASE")
if cred_json:
    cred = credentials.Certificate(json.loads(cred_json))
    firebase_admin.initialize_app(cred)
# NOTE(review): indentation was lost in this diff view — if this line sits
# outside the `if` (as reconstructed here), firestore.client() will raise
# when FIREBASE is unset; confirm against the original file.
fs = firestore.client()
26
 
27
# --------- Google GenAI Client ---------
# Single shared google-genai client; the "Gemini" env var holds the API key.
client = genai.Client(api_key=os.getenv("Gemini"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
# --------- FAISS Cache Paths ---------
# On-disk cache locations: the serialized FAISS index and the pickled list of
# document summary strings it was built from (kept in the working directory).
INDEX_PATH = "vector.index"
DOCS_PATH = "documents.pkl"
33
+
34
+ # --------- Fetch & Summarize Firestore Docs ---------
35
def fetch_documents(db=None) -> list[str]:
    """Summarize every relevant Firestore collection into embedding-ready text.

    Each Firestore document is flattened into one (or, for nested
    interventions, several) one-line natural-language summary strings.

    Args:
        db: Optional Firestore-client-like object exposing
            ``collection(name).stream()``; defaults to the module-level ``fs``.
            Injectable so the summarization logic is testable offline.

    Returns:
        A flat list of summary strings, in collection-list order.
    """
    if db is None:
        db = fs
    docs: list[str] = []
    # (collection name, summarizer) pairs. A summarizer may return a single
    # string or a list of strings (one per nested intervention). Field access
    # uses .get() so a document missing a field yields "None" in its summary
    # instead of a KeyError aborting the whole index build.
    summarizers = [
        ("participants", lambda d: f"{d.get('name')} ({d.get('enterpriseName')}), sector: {d.get('sector')}, stage: {d.get('stage')}, type: {d.get('developmentType')}."),
        ("interventions", lambda d: [f"Intervention: {i.get('title')} under {d.get('area')}." for i in d.get("interventions", [])]),
        ("feedbacks", lambda d: f"Feedback on {d.get('interventionTitle')} by {d.get('smeName')}: {d.get('comment')}"),
        ("complianceDocuments", lambda d: f"Compliance document '{d.get('documentType')}' for {d.get('participantName')} is {d.get('status')} (expires {d.get('expiryDate')})."),
        ("assignedInterventions", lambda d: f"Assigned '{d.get('interventionTitle')}' for {d.get('smeName')} by consultant {d.get('consultantId')} ({d.get('status')})."),
        ("consultants", lambda d: f"Consultant {d.get('name')} – expertise: {', '.join(d.get('expertise', []))}, rating: {d.get('rating')}.")
    ]
    for name, summarize in summarizers:
        for snap in db.collection(name).stream():
            out = summarize(snap.to_dict())
            # Flatten list-valued summaries.
            if isinstance(out, list):
                docs.extend(out)
            else:
                docs.append(out)
    return docs
54
+
55
+ # --------- Embedding Helper ---------
56
def get_embeddings(texts: list[str]) -> list[list[float]]:
    """Return one embedding vector per input string.

    Uses Gemini's ``text-embedding-004`` model through the shared
    module-level ``client``; vectors come back in the same order as *texts*.
    """
    embedding_response = client.models.embed_content(
        model="text-embedding-004",
        contents=texts,
    )
    vectors = []
    for embedding in embedding_response.embeddings:
        vectors.append(embedding.values)
    return vectors
63
 
64
+ # --------- Build or Load FAISS Index ---------
65
def build_or_load_index():
    """Return ``(documents, index)``, using the on-disk cache when present.

    On a cache hit both the pickled corpus and the serialized FAISS index are
    loaded from disk; otherwise the corpus is fetched from Firestore,
    embedded, indexed, and written back to the cache.

    Returns:
        tuple: the list of document summary strings and the FAISS index.

    Raises:
        ValueError: if Firestore yields no documents — an empty corpus
            cannot be embedded or indexed.
    """
    if os.path.exists(INDEX_PATH) and os.path.exists(DOCS_PATH):
        # Cache hit: skip Firestore and the (paid) embedding calls entirely.
        with open(DOCS_PATH, "rb") as f:
            documents = pickle.load(f)
        index = faiss.read_index(INDEX_PATH)
    else:
        documents = fetch_documents()
        if not documents:
            # np.array([]) has no second axis, so embs.shape[1] below would
            # raise an opaque IndexError; fail with a clear message instead.
            raise ValueError("No documents fetched from Firestore; cannot build index.")
        embs = np.array(get_embeddings(documents), dtype="float32")
        # NOTE(review): IndexFlatIP ranks by raw inner product; if cosine
        # similarity is intended, L2-normalize the embeddings
        # (faiss.normalize_L2) before add/search — confirm.
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        # Cache to disk so subsequent startups take the fast path above.
        with open(DOCS_PATH, "wb") as f:
            pickle.dump(documents, f)
        faiss.write_index(index, INDEX_PATH)
    return documents, index
81
 
 
 
 
 
 
 
82
# Build (or load from the on-disk cache) the corpus and its FAISS index once
# at import time; every request then reuses these module-level structures.
documents, index = build_or_load_index()
83
 
84
+ # --------- RAG Chat Helper ---------
85
def retrieve_and_respond(user_query: str, top_k: int = 3) -> str:
    """Answer *user_query* with retrieval-augmented generation.

    Embeds the query, retrieves the ``top_k`` nearest document summaries from
    the module-level FAISS index, and asks the Gemini chat model to answer
    using that context.

    Args:
        user_query: The user's natural-language question.
        top_k: Number of context documents to retrieve (default 3).

    Returns:
        The model's answer text.
    """
    # 1) Embed the query.
    q_emb = np.array(get_embeddings([user_query]), dtype="float32")
    # 2) Search the index. FAISS pads the result with -1 when fewer than
    #    top_k vectors exist; filter those out so we never silently inject
    #    documents[-1] (the last document) as bogus context.
    _, idxs = index.search(q_emb, top_k)
    ctx = "\n\n".join(documents[i] for i in idxs[0] if i != -1)
    # 3) Build the grounded prompt.
    prompt = f"Use the context below to answer:\n\n{ctx}\n\nQuestion: {user_query}\nAnswer:"
    # 4) Chat with Gemini.
    chat = client.chats.create(model="gemini-2.0-flash-thinking-exp")
    resp = chat.send_message(prompt)
    return resp.text
97
+
98
+ # --------- Flask Endpoint ---------
 
 
99
@app.route("/chat", methods=["POST"])
def chat_endpoint():
    """Answer a user question via RAG.

    Expects JSON ``{"user_query": "..."}``. Responds with
    ``{"reply": "..."}`` on success, HTTP 400 when the query is missing, or
    HTTP 500 carrying the error text if retrieval/generation fails.
    """
    payload = request.get_json(force=True)
    user_query = payload.get("user_query")
    if not user_query:
        return jsonify({"error": "Missing user_query"}), 400
    try:
        return jsonify({"reply": retrieve_and_respond(user_query)})
    except Exception as exc:
        return jsonify({"error": str(exc)}), 500
109
 
 
110
if __name__ == "__main__":
    # Hard-coded debug=True on a 0.0.0.0 bind exposes the Werkzeug
    # interactive debugger (arbitrary code execution) to the network.
    # Enable debug only when explicitly requested via the environment,
    # and let PORT override the default 7860.
    app.run(
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        debug=os.getenv("FLASK_DEBUG", "").lower() in ("1", "true"),
    )