princemaxp committed
Commit 78db3ec · verified · 1 Parent(s): dbd06e6

Update app.py

Files changed (1):
  1. app.py +24 -9
app.py CHANGED
@@ -4,11 +4,14 @@ from datetime import datetime, timedelta
 import gradio as gr
 from datasets import load_dataset, Dataset, DatasetDict
 from huggingface_hub import HfFolder
+from sentence_transformers import SentenceTransformer, util
+import torch
+import requests
 
 # ================================
 # CONFIG
 # ================================
-MODEL_TOKEN = os.environ.get("HF_TOKEN")  # for model usage
+MODEL_TOKEN = os.environ.get("HF_TOKEN")  # for model usage
 DATASET_TOKEN = os.environ.get("dataset_HF_TOKEN")  # for dataset updates
 DATASET_NAME = "guardian-ai-qna"
 
@@ -25,7 +28,10 @@ HfFolder.save_token(DATASET_TOKEN)
 try:
     dataset = load_dataset(DATASET_NAME, use_auth_token=DATASET_TOKEN)
 except:
-    dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": []})})
+    dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": [], "embedding": []})})
+
+# Load embedding model
+embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 # ================================
 # HELPER FUNCTIONS
@@ -48,17 +54,27 @@ def log_query(user_id):
     now = datetime.now()
     user_queries.setdefault(user_id, []).append(now)
 
-def find_in_dataset(question):
+def find_in_dataset(question, threshold=0.75):
     if len(dataset["train"]) == 0:
         return None
-    for entry in dataset["train"]:
-        if question.strip().lower() == entry["question"].strip().lower():
-            return entry["answer"]
+    # Compute embedding for input
+    question_emb = embed_model.encode(question, convert_to_tensor=True)
+    # Load existing embeddings
+    existing_embs = torch.tensor(dataset["train"]["embedding"]) if dataset["train"]["embedding"] else None
+    if existing_embs is None or len(existing_embs) == 0:
+        return None
+    # Compute cosine similarities
+    similarities = util.cos_sim(question_emb, existing_embs)[0]
+    max_score, idx = torch.max(similarities, dim=0)
+    if max_score >= threshold:
+        return dataset["train"]["answer"][idx.item()]
     return None
 
 def save_qna(question, answer):
     global dataset
-    new_entry = {"question": [question], "answer": [answer]}
+    # Compute embedding
+    emb = embed_model.encode(question).tolist()
+    new_entry = {"question": [question], "answer": [answer], "embedding": [emb]}
     new_ds = Dataset.from_dict(new_entry)
     dataset["train"] = dataset["train"].concatenate(new_ds)
     dataset["train"].push_to_hub(DATASET_NAME, token=DATASET_TOKEN)
@@ -68,7 +84,6 @@ def call_render(question):
     Replace this with your actual Render API call logic
     that fetches the answer from the internet.
     """
-    import requests
     RENDER_API_URL = os.environ.get("RENDER_API_URL")
     if not RENDER_API_URL:
         return "Render API not configured."
@@ -89,7 +104,7 @@ def chat(history, message, session_id):
 
     log_query(session_id)
 
-    # Check dataset first
+    # Check dataset first (embedding-based)
     response = find_in_dataset(message)
     if response is None:
         # Call Render API fallback
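
The new find_in_dataset drops exact string matching in favour of semantic matching: stored questions and the incoming one are embedded with all-MiniLM-L6-v2, and a stored answer is returned only when the best cosine similarity clears the 0.75 threshold. A minimal, self-contained sketch of that lookup (the sample questions, answers, and the lookup helper are illustrative, not part of the Space):

# Illustrative sketch of the embedding lookup; only the SentenceTransformer,
# util.cos_sim and torch.max calls mirror app.py, the data is made up.
from sentence_transformers import SentenceTransformer, util
import torch

embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

stored_questions = ["How do I reset my password?", "What is Guardian AI?"]
stored_answers = ["Open account settings and choose Reset password.", "A Q&A assistant."]
stored_embs = embed_model.encode(stored_questions, convert_to_tensor=True)

def lookup(question, threshold=0.75):
    # Embed the incoming question and compare it against every stored question.
    q_emb = embed_model.encode(question, convert_to_tensor=True)
    sims = util.cos_sim(q_emb, stored_embs)[0]
    max_score, idx = torch.max(sims, dim=0)
    # Return the paired answer only if the best match clears the threshold.
    return stored_answers[idx.item()] if max_score >= threshold else None

print(lookup("how can i reset my password"))  # near-duplicate question
print(lookup("what's the weather today"))     # unrelated question, likely below the threshold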
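
save_qna now stores an embedding alongside each new question/answer pair before pushing the grown train split back to the Hub. The split is extended via dataset["train"].concatenate(new_ds); in the datasets library, row-wise concatenation is normally done with the module-level concatenate_datasets function, so an equivalent sketch of the append-and-push step (field names from the commit; repo id and token are placeholders) could look like:

# Sketch of the append-and-push step using datasets.concatenate_datasets.
# Field names match app.py; repo_id and token are placeholder arguments.
from datasets import Dataset, concatenate_datasets

def append_qna(dataset, question, answer, embedding, repo_id, token):
    new_row = Dataset.from_dict(
        {"question": [question], "answer": [answer], "embedding": [embedding]}
    )
    # Concatenate the existing train split with the one-row dataset, then push.
    dataset["train"] = concatenate_datasets([dataset["train"], new_row])
    dataset["train"].push_to_hub(repo_id, token=token)
    return dataset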
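
call_render remains a stub in this diff (only the RENDER_API_URL check is visible); the commit moves import requests to module level, which suggests an HTTP fallback. A hypothetical version of that fallback, where the JSON payload and the "answer" field in the response are assumptions because the commit does not show the request body:

# Hypothetical Render fallback; the payload and response shape are assumed.
import os
import requests

def call_render(question):
    render_api_url = os.environ.get("RENDER_API_URL")
    if not render_api_url:
        return "Render API not configured."
    try:
        resp = requests.post(render_api_url, json={"question": question}, timeout=30)
        resp.raise_for_status()
        return resp.json().get("answer", "No answer returned.")
    except requests.RequestException as exc:
        return f"Render API error: {exc}"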