princemaxp committed on
Commit
a8452a4
·
verified ·
1 Parent(s): be72f4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -69
app.py CHANGED
@@ -1,97 +1,91 @@
 
1
  import gradio as gr
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
- from datasets import load_dataset, Dataset, concatenate_datasets
4
- import os
5
-
6
- # -------------------------------
7
- # Config
8
- # -------------------------------
9
- HF_TOKEN = os.environ["dataset_HF_TOKEN"]
10
- DATASET_ID = "your-username/guardian-ai-qna" # replace with your HF username
11
- MODEL_ID = "google/gemma-2b-it"
12
 
13
- SYSTEM_PROMPT = """You are Guardian AI, a friendly cybersecurity educator.
14
- Your goal is to explain cybersecurity concepts in simple, engaging language with examples.
15
- Always keep answers clear, short, and focused on security awareness.
16
- Use the examples from the Q&A memory to improve your answers.
17
- """
18
 
19
- # -------------------------------
20
- # Load model & tokenizer
21
- # -------------------------------
22
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
23
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
24
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
 
25
 
26
- # -------------------------------
27
- # Dataset functions
28
- # -------------------------------
29
- def load_qna_dataset():
30
- try:
31
- dataset = load_dataset(DATASET_ID, use_auth_token=HF_TOKEN)["train"]
32
- except:
33
- dataset = Dataset.from_dict({"question": [], "answer": []})
34
- return dataset
35
 
36
- def save_qna(user_input, response):
37
- dataset = load_qna_dataset()
38
- new_entry = Dataset.from_dict({"question": [user_input], "answer": [response]})
39
- dataset = concatenate_datasets([dataset, new_entry])
40
- dataset.push_to_hub(DATASET_ID, token=HF_TOKEN)
41
 
42
- def retrieve_similar_qna(user_input, top_k=3):
43
- dataset = load_qna_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  if len(dataset) == 0:
45
  return ""
46
- # Simple keyword-based retrieval
47
- # You can upgrade to semantic search later
48
- relevant = []
49
- for q, a in zip(dataset["question"], dataset["answer"]):
50
- if any(word in user_input.lower() for word in q.lower().split()):
51
- relevant.append(f"Q: {q}\nA: {a}")
52
- if len(relevant) >= top_k:
53
- break
54
- return "\n".join(relevant)
55
 
56
- # -------------------------------
57
- # Chat function
58
- # -------------------------------
59
  def chat(history, user_input):
60
- # Retrieve past Q&A for context
61
  context = retrieve_similar_qna(user_input)
62
  prompt = SYSTEM_PROMPT
63
  if context:
64
  prompt += f"\n\nMemory of past Q&A:\n{context}"
65
  prompt += f"\n\nUser: {user_input}\nGuardian AI:"
66
 
67
- result = generator(
68
- prompt,
69
- max_new_tokens=200,
70
- do_sample=True,
71
- temperature=0.7,
72
- top_p=0.9
73
- )[0]["generated_text"]
 
74
 
75
  response = result.split("Guardian AI:")[-1].strip()
76
  history.append((user_input, response))
77
  save_qna(user_input, response)
78
  return history, history
79
 
80
- # -------------------------------
81
- # Gradio UI
82
- # -------------------------------
83
- with gr.Blocks() as demo:
84
- gr.Markdown("## 🛡️ Guardian AI – Cybersecurity Educator")
85
- chatbot = gr.Chatbot(type="messages") # Updated type to avoid deprecation warning
86
  state = gr.State([])
87
-
88
  with gr.Row():
89
- with gr.Column(scale=8):
90
- user_input = gr.Textbox(show_label=False, placeholder="Ask me about cybersecurity...")
91
- with gr.Column(scale=2):
92
- send_btn = gr.Button("Send")
93
 
94
- send_btn.click(chat, [state, user_input], [chatbot, state])
95
- user_input.submit(chat, [state, user_input], [chatbot, state])
96
 
97
- demo.launch()
 
1
+ import os
2
  import gradio as gr
3
+ import torch
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
+ from datasets import load_dataset, Dataset
 
 
 
 
 
 
 
 
6
 
7
# ---------- CONFIG ----------
MODEL_ID = "YOUR_MODEL_ID_HF"  # Replace with your HF model ID
# Hub repository that stores the Q&A memory.
# NOTE(review): a bare name carries no namespace; loading from the Hub
# probably needs the "username/guardian-ai-qna" form — confirm.
DATASET_NAME = "guardian-ai-qna"
SYSTEM_PROMPT = "You are Guardian AI, a cybersecurity expert. Answer concisely."

# ---------- LOAD TOKENIZER & MODEL ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Half precision is only requested when a GPU is present; the CPU path uses
# float32 so generation does not run in float16 on CPU.
_use_cuda = torch.cuda.is_available()
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
)
device = 0 if _use_cuda else -1
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
)

# ---------- LOAD DATASET ----------
# The first positional argument of load_dataset is the dataset path/repo id,
# so the repo name goes there (not in the config-name slot).
try:
    dataset = load_dataset(DATASET_NAME, split="train")
except Exception as exc:
    # Best effort: start with an empty memory, but say why the load failed
    # instead of hiding it behind a bare except.
    print(
        f"Could not load dataset {DATASET_NAME!r} ({exc}); "
        "starting with an empty Q&A memory."
    )
    dataset = Dataset.from_dict({"question": [], "answer": []})

# ---------- EMBEDDING HELPER ----------
from sentence_transformers import SentenceTransformer, util
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# Cache embeddings in memory: a tensor with one row per stored question, or
# an empty list when nothing has been stored yet.
if len(dataset) > 0:
    dataset_embeddings = embedder.encode(
        list(dataset["question"]),
        convert_to_tensor=True,
    )
else:
    dataset_embeddings = []
33
+
34
+ # ---------- SAVE QNA FUNCTION ----------
35
def save_qna(question, answer):
    """Append one Q&A pair to the in-memory dataset and push it to the Hub.

    Rebinds the module-level ``dataset`` and ``dataset_embeddings``. After
    this call ``dataset_embeddings`` is always a 2-D tensor with one row per
    stored question, whether it was previously an empty list, a list of
    1-D tensors, or already a tensor.

    Args:
        question: The user's question text.
        answer: The generated answer text.

    Raises:
        Whatever ``Dataset.push_to_hub`` raises (e.g. missing/invalid token);
        the in-memory state has already been updated at that point.
    """
    global dataset, dataset_embeddings

    # list(...) keeps the concatenation valid even if column access does not
    # return a plain Python list.
    dataset = Dataset.from_dict({
        "question": list(dataset["question"]) + [question],
        "answer": list(dataset["answer"]) + [answer],
    })

    # New embedding as a single-row matrix so it can be concatenated.
    new_row = embedder.encode(question, convert_to_tensor=True).reshape(1, -1)

    cached = dataset_embeddings
    if isinstance(cached, (list, tuple)):
        # Older state: a Python list of 1-D tensors (possibly empty).
        cached = torch.stack(list(cached)) if cached else None

    if cached is None or cached.numel() == 0:
        dataset_embeddings = new_row
    else:
        # A tensor has no .append(); grow the cache by concatenating rows.
        dataset_embeddings = torch.cat(
            [cached, new_row.to(cached.device)], dim=0
        )

    # Push to the HF dataset repo.
    dataset.push_to_hub(DATASET_NAME, token=os.environ.get("HF_TOKEN"))
46
+
47
+ # ---------- RETRIEVE SIMILAR QNA ----------
48
def retrieve_similar_qna(query, top_k=3):
    """Return the stored Q&A pairs most similar to ``query`` as prompt text.

    Similarity is the cosine similarity between the embedding of ``query``
    and the cached question embeddings.

    Args:
        query: The user's question text.
        top_k: Maximum number of Q&A pairs to include.

    Returns:
        A string of ``"Q: ...\\nA: ...\\n"`` entries, most similar first, or
        ``""`` when there is nothing stored or nothing to return.
    """
    if len(dataset) == 0:
        return ""

    corpus = dataset_embeddings
    if isinstance(corpus, (list, tuple)):
        # The cache may be a list of 1-D tensors; cos_sim needs a matrix.
        if not corpus:
            return ""
        corpus = torch.stack(list(corpus))
    if corpus.shape[0] == 0:
        return ""

    query_emb = embedder.encode(query, convert_to_tensor=True)
    similarities = util.cos_sim(query_emb, corpus)[0]

    # Never ask for more hits than there are embeddings or stored rows.
    k = min(top_k, similarities.shape[0], len(dataset))
    if k <= 0:
        return ""
    top_results = similarities.topk(k=k)

    # .tolist() yields plain ints, which are valid dataset row indices.
    rows = (dataset[i] for i in top_results.indices.tolist())
    return "".join(
        f"Q: {row['question']}\nA: {row['answer']}\n" for row in rows
    )
 
 
58
 
59
+ # ---------- CHAT FUNCTION ----------
 
 
60
def chat(history, user_input):
    """Generate a reply to ``user_input`` and record the exchange.

    Builds a prompt from the system prompt, any retrieved Q&A memory and the
    user's message, samples a completion, keeps the text after the last
    "Guardian AI:" marker, appends the turn to ``history`` (in place) and
    persists the pair via ``save_qna``.

    Args:
        history: List of ``(user_message, reply)`` tuples; mutated in place.
        user_input: The message typed by the user.

    Returns:
        ``(history, history)`` — one copy for the chatbot display and one
        for the session state.
    """
    memory = retrieve_similar_qna(user_input)

    # Prompt sections are separated by one blank line.
    sections = [SYSTEM_PROMPT]
    if memory:
        sections.append(f"Memory of past Q&A:\n{memory}")
    sections.append(f"User: {user_input}\nGuardian AI:")
    prompt = "\n\n".join(sections)

    with torch.no_grad():
        outputs = generator(
            prompt,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.6,
            top_p=0.85,
        )
    generated = outputs[0]["generated_text"]

    # Keep only what follows the final "Guardian AI:" marker (the whole
    # text when the marker is absent).
    _, _, tail = generated.rpartition("Guardian AI:")
    response = tail.strip()

    history.append((user_input, response))
    save_qna(user_input, response)
    return history, history
80
 
81
+ # ---------- GRADIO APP ----------
82
with gr.Blocks() as app:
    # Conversation display plus per-session history storage.
    chatbot = gr.Chatbot()
    state = gr.State([])

    # Message box and send button side by side.
    with gr.Row():
        user_msg = gr.Textbox(label="Type your message")
        send_btn = gr.Button("Send")

    # Clicking "Send" runs chat() and refreshes both the display and state.
    send_btn.click(
        fn=chat,
        inputs=[state, user_msg],
        outputs=[chatbot, state],
    )

# Start the server and request a public share link.
app.launch(share=True)