Update app.py
app.py
CHANGED
@@ -4,11 +4,14 @@ from datetime import datetime, timedelta
 import gradio as gr
 from datasets import load_dataset, Dataset, DatasetDict
 from huggingface_hub import HfFolder
+from sentence_transformers import SentenceTransformer, util
+import torch
+import requests
 
 # ================================
 # CONFIG
 # ================================
-MODEL_TOKEN = os.environ.get("HF_TOKEN")
+MODEL_TOKEN = os.environ.get("HF_TOKEN")  # for model usage
 DATASET_TOKEN = os.environ.get("dataset_HF_TOKEN")  # for dataset updates
 DATASET_NAME = "guardian-ai-qna"
 
@@ -25,7 +28,10 @@ HfFolder.save_token(DATASET_TOKEN)
 try:
     dataset = load_dataset(DATASET_NAME, use_auth_token=DATASET_TOKEN)
 except:
-    dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": []})})
+    dataset = DatasetDict({"train": Dataset.from_dict({"question": [], "answer": [], "embedding": []})})
+
+# Load embedding model
+embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 # ================================
 # HELPER FUNCTIONS
@@ -48,17 +54,27 @@ def log_query(user_id):
     now = datetime.now()
     user_queries.setdefault(user_id, []).append(now)
 
-def find_in_dataset(question):
+def find_in_dataset(question, threshold=0.75):
     if len(dataset["train"]) == 0:
         return None
-
-
-
+    # Compute embedding for input
+    question_emb = embed_model.encode(question, convert_to_tensor=True)
+    # Load existing embeddings
+    existing_embs = torch.tensor(dataset["train"]["embedding"]) if dataset["train"]["embedding"] else None
+    if existing_embs is None or len(existing_embs) == 0:
+        return None
+    # Compute cosine similarities
+    similarities = util.cos_sim(question_emb, existing_embs)[0]
+    max_score, idx = torch.max(similarities, dim=0)
+    if max_score >= threshold:
+        return dataset["train"]["answer"][idx.item()]
     return None
 
 def save_qna(question, answer):
     global dataset
-
+    # Compute embedding
+    emb = embed_model.encode(question).tolist()
+    new_entry = {"question": [question], "answer": [answer], "embedding": [emb]}
     new_ds = Dataset.from_dict(new_entry)
     dataset["train"] = dataset["train"].concatenate(new_ds)
     dataset["train"].push_to_hub(DATASET_NAME, token=DATASET_TOKEN)
@@ -68,7 +84,6 @@ def call_render(question):
     Replace this with your actual Render API call logic
     that fetches the answer from the internet.
     """
-    import requests
     RENDER_API_URL = os.environ.get("RENDER_API_URL")
     if not RENDER_API_URL:
         return "Render API not configured."
@@ -89,7 +104,7 @@ def chat(history, message, session_id):
 
     log_query(session_id)
 
-    # Check dataset first
+    # Check dataset first (embedding-based)
     response = find_in_dataset(message)
     if response is None:
         # Call Render API fallback
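The heart of this change is the embedding-based lookup: reuse a stored answer when a new question is semantically close enough to one already in the dataset, and only fall back to the Render API when it is not. A minimal, self-contained sketch of that idea follows; the example questions and the 0.75 threshold are illustrative, not taken from the Space's data.

from sentence_transformers import SentenceTransformer, util

# Hypothetical cache of previously answered questions (illustrative data only)
qna = {
    "How do I reset my password?": "Use the 'Forgot password' link on the login page.",
    "Which file formats can I upload?": "PNG, JPEG and PDF are supported.",
}

embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cached_questions = list(qna.keys())
cached_embs = embed_model.encode(cached_questions, convert_to_tensor=True)

def lookup(question, threshold=0.75):
    # Embed the incoming question and compare it to every cached question.
    q_emb = embed_model.encode(question, convert_to_tensor=True)
    scores = util.cos_sim(q_emb, cached_embs)[0]  # one similarity per cached question
    best = scores.argmax().item()
    # Reuse the cached answer only if the best match clears the threshold;
    # otherwise the caller should fall back to a live lookup.
    if scores[best] >= threshold:
        return qna[cached_questions[best]]
    return None

print(lookup("how can i reset my account password"))  # close match -> cached answer
print(lookup("what's the weather today"))             # no match -> None

The encode/cos_sim calls above are the same sentence-transformers primitives the new find_in_dataset relies on; the threshold trades precision for recall, so raising it means fewer reused answers and more calls to the Render fallback.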