princemaxp committed on
Commit
ac8d37b
·
verified ·
1 Parent(s): d395fd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -100
app.py CHANGED
@@ -1,76 +1,66 @@
1
  import time
 
2
  import gradio as gr
3
  from datasets import load_dataset, Dataset
4
- from huggingface_hub import hf_hub_download
5
  from sentence_transformers import SentenceTransformer, util
6
- import torch
 
7
 
8
- # ---------------------------
9
- # CONFIGURATION
10
- # ---------------------------
11
- HF_TOKEN = "<YOUR_HF_TOKEN>" # set your HF token
12
  DATASET_NAME = "guardian-ai-qna"
13
- MAX_QUESTIONS = 5 # max questions per TIME_WINDOW
14
- TIME_WINDOW = 3600 # 1 hour in seconds
15
- EMBED_MODEL = "all-MiniLM-L6-v2" # small but effective embedding model
16
 
17
- # ---------------------------
18
- # LOAD OR CREATE DATASET
19
- # ---------------------------
20
  try:
21
- dataset = load_dataset(DATASET_NAME, use_auth_token=HF_TOKEN)
22
- dataset = dataset["train"]
23
  except:
24
  dataset = Dataset.from_dict({"question": [], "answer": []})
25
 
26
- # ---------------------------
27
- # EMBEDDING MODEL
28
- # ---------------------------
29
- embedder = SentenceTransformer(EMBED_MODEL)
30
-
31
- # Precompute embeddings for existing Q&A
32
  if len(dataset) > 0:
33
- dataset_embeddings = embedder.encode(dataset["question"], convert_to_tensor=True)
34
  else:
35
- dataset_embeddings = torch.empty((0, embedder.get_sentence_embedding_dimension()))
36
 
37
- # ---------------------------
38
- # USER RATE LIMITING
39
- # ---------------------------
40
- user_limits = {}
41
 
42
  def check_rate_limit(session_id):
43
- current_time = time.time()
44
- if session_id not in user_limits:
45
- user_limits[session_id] = {"count": 0, "start_time": current_time}
46
-
47
- user_data = user_limits[session_id]
48
- if current_time - user_data["start_time"] > TIME_WINDOW:
49
- user_data["count"] = 0
50
- user_data["start_time"] = current_time
51
-
52
- if user_data["count"] >= MAX_QUESTIONS:
53
- return False, f"You have reached the max of {MAX_QUESTIONS} questions. Please wait before asking more."
54
-
55
- user_data["count"] += 1
56
- return True, None
57
-
58
- # ---------------------------
59
- # HELPER FUNCTIONS
60
- # ---------------------------
61
- def find_similar_answer(user_input):
62
- if len(dataset) == 0:
63
  return None
64
-
65
- query_emb = embedder.encode(user_input, convert_to_tensor=True)
66
- scores = util.cos_sim(query_emb, dataset_embeddings)
67
- top_idx = torch.argmax(scores)
68
- top_score = scores[0][top_idx].item()
69
-
70
- if top_score > 0.6: # threshold for similarity
71
- return dataset["answer"][top_idx]
72
- return None
73
 
 
 
 
74
  def save_qna(question, answer):
75
  global dataset, dataset_embeddings
76
  new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
@@ -78,53 +68,53 @@ def save_qna(question, answer):
78
  "question": dataset["question"] + new_entry["question"],
79
  "answer": dataset["answer"] + new_entry["answer"]
80
  })
81
-
82
- # update embeddings incrementally
83
- new_emb = embedder.encode([question], convert_to_tensor=True)
84
- if len(dataset_embeddings) == 0:
85
- dataset_embeddings = new_emb
86
- else:
87
- dataset_embeddings = torch.vstack([dataset_embeddings, new_emb])
88
-
89
- # save to HF dataset (push to hub)
90
  dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
91
 
92
- # ---------------------------
93
- # MAIN CHAT FUNCTION
94
- # ---------------------------
95
- def chat(history, user_input, session_id="default"):
96
- # Rate limit check
97
- allowed, message = check_rate_limit(session_id)
 
 
 
 
 
 
 
 
 
 
 
98
  if not allowed:
99
- history.append(("System", message))
100
- return history, history
101
-
102
- # Check existing similar Q&A
103
- response = find_similar_answer(user_input)
104
-
105
- if not response:
106
- # Fallback / simple generative response
107
- response = f"Guardian AI: Sorry, I don’t know the answer yet. I’m learning!"
108
-
109
- # Save new Q&A for incremental learning
110
- save_qna(user_input, response)
111
-
112
- # Update chat history
113
- history.append((user_input, response))
114
- return history, history
115
-
116
- # ---------------------------
117
- # GRADIO INTERFACE
118
- # ---------------------------
119
  with gr.Blocks() as app:
 
120
  chatbot = gr.Chatbot()
121
- msg = gr.Textbox(label="Your question")
122
- session_state = gr.State("default") # default session
123
-
124
- def user_submit(message, history, session_id):
125
- return chat(history, message, session_id)
126
-
127
- msg.submit(user_submit, inputs=[msg, chatbot, session_state], outputs=[chatbot, chatbot])
128
-
129
- # Launch app
130
- app.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
 
 
 
1
  import time
2
+ from collections import defaultdict
3
  import gradio as gr
4
  from datasets import load_dataset, Dataset
 
5
  from sentence_transformers import SentenceTransformer, util
6
+ import requests
7
+ import os
8
 
9
# =======================
# Configuration
# =======================
# The Hub token is read from the environment so it never lives in the source.
HF_TOKEN = os.environ.get("HF_TOKEN")
DATASET_NAME = "guardian-ai-qna"
RENDER_API_URL = "https://your-render-api.com/get_answer"  # Replace with your Render API
MAX_QUERIES_PER_HOUR = 5
# Minimum cosine similarity for a stored question to count as a match.
SIMILARITY_THRESHOLD = 0.75

# =======================
# Load dataset
# =======================
try:
    # `token=` matches the keyword already used for push_to_hub in this file.
    dataset = load_dataset(DATASET_NAME, token=HF_TOKEN)["train"]
except Exception as e:
    # Best effort: start with an empty store, but say why instead of hiding
    # the failure (a bare `except:` would also swallow KeyboardInterrupt).
    print("Could not load dataset, starting with an empty one:", e)
    dataset = Dataset.from_dict({"question": [], "answer": []})

# Initialize embeddings
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
if len(dataset) > 0:
    dataset_embeddings = embed_model.encode(dataset["question"], convert_to_tensor=True)
else:
    # No stored questions yet; find_in_dataset treats None as "nothing to search".
    dataset_embeddings = None
32
 
33
# =======================
# Rate limiting
# =======================
RATE_LIMIT_WINDOW_SECONDS = 3600  # length of the sliding window (one hour)
user_queries = defaultdict(list)  # {session_id: [timestamps]}


def check_rate_limit(session_id, limit=None, window=RATE_LIMIT_WINDOW_SECONDS):
    """Record a query for ``session_id`` if it is within the allowed rate.

    Args:
        session_id: Key identifying the caller in ``user_queries``.
        limit: Maximum number of queries allowed per window. Defaults to
            the module-level ``MAX_QUERIES_PER_HOUR``.
        window: Sliding-window length in seconds.

    Returns:
        ``(True, 0)`` when the query is allowed (and has been recorded), or
        ``(False, seconds_to_wait)`` when the limit has been reached.
    """
    if limit is None:
        limit = MAX_QUERIES_PER_HOUR

    now = time.time()
    # Keep only the queries that still fall inside the window.
    recent = [t for t in user_queries[session_id] if now - t < window]
    user_queries[session_id] = recent

    if len(recent) >= limit:
        # With a limit of 0 there is no oldest entry; wait a full window.
        oldest = recent[0] if recent else now
        return False, window - (now - oldest)

    recent.append(now)
    return True, 0
46
+
47
# =======================
# Dataset search
# =======================
def find_in_dataset(user_input):
    """Return the stored answer whose question best matches ``user_input``.

    The input is embedded with ``embed_model`` and compared by cosine
    similarity against ``dataset_embeddings``. Returns ``None`` when there
    are no embeddings to search or when the best score is below
    ``SIMILARITY_THRESHOLD``.
    """
    nothing_stored = dataset_embeddings is None or len(dataset_embeddings) == 0
    if nothing_stored:
        return None

    query_vector = embed_model.encode(user_input, convert_to_tensor=True)
    similarities = util.cos_sim(query_vector, dataset_embeddings)[0]
    best = similarities.argmax().item()

    if not (similarities[best] < SIMILARITY_THRESHOLD):
        return dataset["answer"][best]
    return None
 
 
 
 
 
 
 
 
60
 
61
+ # =======================
62
+ # Save Q&A to dataset
63
+ # =======================
64
  def save_qna(question, answer):
65
  global dataset, dataset_embeddings
66
  new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
 
68
  "question": dataset["question"] + new_entry["question"],
69
  "answer": dataset["answer"] + new_entry["answer"]
70
  })
71
+ dataset_embeddings = embed_model.encode(dataset["question"], convert_to_tensor=True)
 
 
 
 
 
 
 
 
72
  dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
73
 
74
# =======================
# Render API fallback
# =======================
def call_render_api(question):
    """Ask the Render API for an answer to ``question``.

    Posts the question as JSON to ``RENDER_API_URL`` with a 10-second
    timeout. Returns the ``"answer"`` field of a 200 response, or the
    fixed fallback text when the field is missing, the status is not 200,
    or any error occurs (errors are printed, never raised).
    """
    fallback = "Sorry, no answer found."
    try:
        reply = requests.post(RENDER_API_URL, json={"question": question}, timeout=10)
        if reply.status_code != 200:
            return fallback
        return reply.json().get("answer", fallback)
    except Exception as err:
        print("Render API error:", err)
    return fallback
85
+
86
# =======================
# Chat function
# =======================
def chat(history, user_input, session_id):
    """Answer one user message and return the updated chat history.

    Args:
        history: List of ``(user, bot)`` tuples shown so far; ``None`` is
            treated as an empty history.
        user_input: The text the user submitted.
        session_id: Key used for per-session rate limiting.

    Returns:
        A new history list. Blank input returns the history unchanged and
        does not count against the rate limit.
    """
    import math

    history = list(history or [])

    # Ignore blank submissions before they consume quota or hit the API.
    if not user_input or not user_input.strip():
        return history

    allowed, wait_time = check_rate_limit(session_id)
    if not allowed:
        # Round up so the user is never told to wait "0 minutes".
        minutes = max(1, math.ceil(wait_time / 60))
        return history + [(f"Rate limit reached. Please wait {minutes} minutes.", "")]

    answer = find_in_dataset(user_input)
    if not answer:
        answer = call_render_api(user_input)
        # Do not cache the "no answer" fallback: once stored it would be
        # served from the dataset forever and the API would never be retried.
        if answer and answer != "Sorry, no answer found.":
            try:
                save_qna(user_input, answer)
            except Exception as e:
                # Persisting is best effort; the user still gets the answer.
                print("Failed to save Q&A:", e)

    history.append((user_input, answer))
    return history
101
+
102
# =======================
# Gradio App
# =======================
with gr.Blocks() as app:
    session_id = gr.State()
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Ask Guardian AI")
    with gr.Row():
        clear = gr.Button("Clear Chat")

    def start_session():
        """Return a fresh, unique session id."""
        import uuid

        return uuid.uuid4().hex

    # Assign the id on every page load. Setting `session_id.value` once at
    # build time would hand every visitor the same id, so all users would
    # share a single rate-limit bucket.
    app.load(start_session, inputs=None, outputs=session_id)

    msg.submit(chat, inputs=[chatbot, msg, session_id], outputs=[chatbot])
    clear.click(lambda: [], None, chatbot)

app.launch()