princemaxp committed on
Commit
dbd06e6
·
verified ·
1 Parent(s): ac8d37b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -105
app.py CHANGED
@@ -1,120 +1,117 @@
 
1
  import time
2
- from collections import defaultdict
3
  import gradio as gr
4
- from datasets import load_dataset, Dataset
5
- from sentence_transformers import SentenceTransformer, util
6
- import requests
7
- import os
8
-
9
- # =======================
10
- # Configuration
11
- # =======================
12
- HF_TOKEN = os.environ.get("HF_TOKEN")
13
  DATASET_NAME = "guardian-ai-qna"
14
- RENDER_API_URL = "https://your-render-api.com/get_answer" # Replace with your Render API
15
- MAX_QUERIES_PER_HOUR = 5
16
- SIMILARITY_THRESHOLD = 0.75
17
 
18
- # =======================
19
- # Load dataset
20
- # =======================
 
 
 
 
 
 
 
21
  try:
22
- dataset = load_dataset(DATASET_NAME, use_auth_token=HF_TOKEN)["train"]
23
  except:
24
- dataset = Dataset.from_dict({"question": [], "answer": []})
25
-
26
- # Initialize embeddings
27
- embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
28
- if len(dataset) > 0:
29
- dataset_embeddings = embed_model.encode(dataset["question"], convert_to_tensor=True)
30
- else:
31
- dataset_embeddings = None
32
-
33
- # =======================
34
- # Rate limiting
35
- # =======================
36
- user_queries = defaultdict(list) # {session_id: [timestamps]}
37
-
38
- def check_rate_limit(session_id):
39
- now = time.time()
40
- # Keep only queries in the last hour
41
- user_queries[session_id] = [t for t in user_queries[session_id] if now - t < 3600]
42
- if len(user_queries[session_id]) >= MAX_QUERIES_PER_HOUR:
43
- return False, 3600 - (now - user_queries[session_id][0])
44
- user_queries[session_id].append(now)
45
  return True, 0
46
 
47
- # =======================
48
- # Dataset search
49
- # =======================
50
- def find_in_dataset(user_input):
51
- global dataset_embeddings
52
- if dataset_embeddings is None or len(dataset_embeddings) == 0:
53
- return None
54
- user_emb = embed_model.encode(user_input, convert_to_tensor=True)
55
- cos_scores = util.cos_sim(user_emb, dataset_embeddings)[0]
56
- top_idx = cos_scores.argmax().item()
57
- if cos_scores[top_idx] < SIMILARITY_THRESHOLD:
58
  return None
59
- return dataset["answer"][top_idx]
 
 
 
60
 
61
- # =======================
62
- # Save Q&A to dataset
63
- # =======================
64
  def save_qna(question, answer):
65
- global dataset, dataset_embeddings
66
- new_entry = Dataset.from_dict({"question": [question], "answer": [answer]})
67
- dataset = Dataset.from_dict({
68
- "question": dataset["question"] + new_entry["question"],
69
- "answer": dataset["answer"] + new_entry["answer"]
70
- })
71
- dataset_embeddings = embed_model.encode(dataset["question"], convert_to_tensor=True)
72
- dataset.push_to_hub(DATASET_NAME, token=HF_TOKEN)
73
-
74
- # =======================
75
- # Render API fallback
76
- # =======================
77
- def call_render_api(question):
78
- try:
79
- response = requests.post(RENDER_API_URL, json={"question": question}, timeout=10)
80
- if response.status_code == 200:
81
- return response.json().get("answer", "Sorry, no answer found.")
82
- except Exception as e:
83
- print("Render API error:", e)
84
- return "Sorry, no answer found."
85
-
86
- # =======================
87
- # Chat function
88
- # =======================
89
- def chat(history, user_input, session_id):
90
- allowed, wait_time = check_rate_limit(session_id)
 
91
  if not allowed:
92
- return history + [(f"Rate limit reached. Please wait {int(wait_time//60)} minutes.", "")]
93
-
94
- answer = find_in_dataset(user_input)
95
- if not answer:
96
- answer = call_render_api(user_input)
97
- save_qna(user_input, answer)
98
-
99
- history.append((user_input, answer))
100
- return history
101
-
102
- # =======================
103
- # Gradio App
104
- # =======================
105
- with gr.Blocks() as app:
106
- session_id = gr.State()
 
 
 
 
 
 
107
  chatbot = gr.Chatbot()
108
- msg = gr.Textbox(label="Ask Guardian AI")
109
- with gr.Row():
110
- clear = gr.Button("Clear Chat")
111
-
112
- def start_session():
113
- return str(time.time()) # simple session id
114
-
115
- session_id.value = start_session()
116
 
117
- msg.submit(chat, inputs=[chatbot, msg, session_id], outputs=[chatbot])
118
- clear.click(lambda: [], None, chatbot)
119
 
120
- app.launch()
 
1
+ import os
2
  import time
3
+ from datetime import datetime, timedelta
4
  import gradio as gr
5
+ from datasets import load_dataset, Dataset, DatasetDict
6
+ from huggingface_hub import HfFolder
7
+
8
+ # ================================
9
+ # CONFIG
10
+ # ================================
11
+ MODEL_TOKEN = os.environ.get("HF_TOKEN") # for model usage
12
+ DATASET_TOKEN = os.environ.get("dataset_HF_TOKEN") # for dataset updates
 
13
  DATASET_NAME = "guardian-ai-qna"
 
 
 
14
 
15
+ MAX_QUERIES = 5 # max queries per user per window
16
+ WINDOW_HOURS = 1 # time window for rate limiting
17
+
18
+ # Rate limiter store
19
+ user_queries = {}
20
+
21
+ # Save dataset token for pushes
22
+ HfFolder.save_token(DATASET_TOKEN)
23
+
24
+ # Load or create dataset
25
  try:
26
+ dataset = load_dataset(DATASET_NAME, use_auth_token=DATASET_TOKEN)
27
  except:
28
+ dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": []})})
29
+
30
+ # ================================
31
+ # HELPER FUNCTIONS
32
+ # ================================
33
+
34
+ def check_rate_limit(user_id):
35
+ now = datetime.now()
36
+ queries = user_queries.get(user_id, [])
37
+ # Remove expired queries
38
+ queries = [q for q in queries if q > now - timedelta(hours=WINDOW_HOURS)]
39
+ user_queries[user_id] = queries
40
+
41
+ if len(queries) >= MAX_QUERIES:
42
+ next_allowed = min(queries) + timedelta(hours=WINDOW_HOURS)
43
+ wait_seconds = int((next_allowed - now).total_seconds())
44
+ return False, wait_seconds
 
 
 
 
45
  return True, 0
46
 
47
+ def log_query(user_id):
48
+ now = datetime.now()
49
+ user_queries.setdefault(user_id, []).append(now)
50
+
51
+ def find_in_dataset(question):
52
+ if len(dataset["train"]) == 0:
 
 
 
 
 
53
  return None
54
+ for entry in dataset["train"]:
55
+ if question.strip().lower() == entry["question"].strip().lower():
56
+ return entry["answer"]
57
+ return None
58
 
 
 
 
59
  def save_qna(question, answer):
60
+ global dataset
61
+ new_entry = {"question": [question], "answer": [answer]}
62
+ new_ds = Dataset.from_dict(new_entry)
63
+ dataset["train"] = dataset["train"].concatenate(new_ds)
64
+ dataset["train"].push_to_hub(DATASET_NAME, token=DATASET_TOKEN)
65
+
66
+ def call_render(question):
67
+ """
68
+ Replace this with your actual Render API call logic
69
+ that fetches the answer from the internet.
70
+ """
71
+ import requests
72
+ RENDER_API_URL = os.environ.get("RENDER_API_URL")
73
+ if not RENDER_API_URL:
74
+ return "Render API not configured."
75
+ resp = requests.post(RENDER_API_URL, json={"question": question})
76
+ if resp.status_code == 200:
77
+ return resp.json().get("answer", "No answer found.")
78
+ return "Error fetching answer from Render."
79
+
80
+ # ================================
81
+ # CHAT FUNCTION
82
+ # ================================
83
+
84
+ def chat(history, message, session_id):
85
+ # Rate limit
86
+ allowed, wait_seconds = check_rate_limit(session_id)
87
  if not allowed:
88
+ return history + [(f"System", f"Rate limit reached. Try again in {wait_seconds//60} minutes.")], ""
89
+
90
+ log_query(session_id)
91
+
92
+ # Check dataset first
93
+ response = find_in_dataset(message)
94
+ if response is None:
95
+ # Call Render API fallback
96
+ response = call_render(message)
97
+ # Save in dataset
98
+ save_qna(message, response)
99
+
100
+ history.append(("User", message))
101
+ history.append(("Guardian AI", response))
102
+ return history, ""
103
+
104
+ # ================================
105
+ # GRADIO UI
106
+ # ================================
107
+ with gr.Blocks() as demo:
108
+ gr.Markdown("## Guardian AI Chatbot")
109
  chatbot = gr.Chatbot()
110
+ session_id = gr.Textbox(label="Session ID (unique per user)", value=str(time.time()), visible=False)
111
+ msg = gr.Textbox(label="Enter your message")
112
+ send_btn = gr.Button("Send")
 
 
 
 
 
113
 
114
+ send_btn.click(fn=chat, inputs=[chatbot, msg, session_id], outputs=[chatbot, msg])
115
+ msg.submit(fn=chat, inputs=[chatbot, msg, session_id], outputs=[chatbot, msg])
116
 
117
+ demo.launch(server_name="0.0.0.0", server_port=7860)