ewingreen commited on
Commit
ad83769
·
verified ·
1 Parent(s): a42c4f7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import difflib # for fuzzy matching
4
+ from sentence_transformers import SentenceTransformer
5
+ import torch
6
+ import numpy as np
7
+
8
+ # ==== Load knowledge ====
9
+ with open("knowledge.txt", "r", encoding="utf-8") as f:
10
+ knowledge_text = f.read()
11
+
12
+ chunks = [chunk.strip() for chunk in knowledge_text.split("\n\n") if chunk.strip()]
13
+
14
+ embedder = SentenceTransformer('all-MiniLM-L6-v2')
15
+ chunk_embeddings = embedder.encode(chunks, convert_to_tensor=True)
16
+
17
+ def get_relevant_context(query, top_k=3):
18
+ query_embedding = embedder.encode(query, convert_to_tensor=True)
19
+ query_embedding = query_embedding / query_embedding.norm()
20
+ norm_chunk_embeddings = chunk_embeddings / chunk_embeddings.norm(dim=1, keepdim=True)
21
+ similarities = torch.matmul(norm_chunk_embeddings, query_embedding)
22
+ top_k_indices = torch.topk(similarities, k=top_k).indices.cpu().numpy()
23
+ return "\n\n".join([chunks[i] for i in top_k_indices])
24
+
25
+ # ==== Model ====
26
+ client = InferenceClient("google/gemma-2-2b-it")
27
+
28
+ def is_driving_related(message):
29
+ driving_keywords = [
30
+ "drive", "driving", "permit", "car", "road", "lane", "traffic",
31
+ "license", "parallel", "park", "stop sign", "brake", "accelerate",
32
+ "merge", "intersection", "seatbelt", "speed limit", "turn signal",
33
+ "pulled over", "parking", "roundabout"
34
+ ]
35
+ message_words = message.lower().split()
36
+ for word in message_words:
37
+ for keyword in driving_keywords:
38
+ if difflib.SequenceMatcher(None, word, keyword).ratio() > 0.8:
39
+ return True
40
+ return False
41
+
42
+ def respond(message, history):
43
+ if not history:
44
+ if not is_driving_related(message):
45
+ return "Hey there! I’m Drive Wise 🚗 — I can only help with driving topics like road rules, parking tips, or permit prep. What driving question do you have for me?"
46
+
47
+ messages = [{"role": "system", "content": """You are Drive Wise, a friendly and supportive AI chatbot designed to help new drivers learn essential driving skills and traffic laws. Your goal is to make learning to drive simple, confident, and stress-free."""}]
48
+
49
+ for user_msg, assistant_msg in history:
50
+ messages.append({"role": "user", "content": user_msg})
51
+ messages.append({"role": "assistant", "content": assistant_msg})
52
+
53
+ messages.append({"role": "user", "content": message})
54
+
55
+ response = ""
56
+
57
+ # for msg in client.chat_completion(messages, max_tokens=500, temperature=.1, stream=True):
58
+ # token = msg.choices[0].delta.content
59
+ # if token:
60
+ # response += token
61
+ # return response
62
+
63
+
64
+ # iterate through each message in the method
65
+ try:
66
+ for message in client.chat_completion(
67
+ messages,
68
+ max_tokens=500,
69
+ temperature=0.1,
70
+ stream=True
71
+ ):
72
+ # add the tokens to the output content
73
+ token = message.choices[0].delta.content # capture the most recent token
74
+ response += token # Add it to the response
75
+ yield response # yield the response
76
+
77
+ except Exception as e:
78
+ print(f"An error occurred: {e}")
79
+
80
+ # === UI ===
81
+ about_text = """
82
+ ## About this bot
83
+ This chatbot will help you learn more about driving!
84
+ Ask me about road rules, parking, your learner’s permit, or how to handle situations like being pulled over.
85
+ """
86
+
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown(about_text)
89
+
90
+ chatbot = gr.Chatbot(label="PERMIT TEST QUESTIONS")
91
+
92
+ with gr.Row():
93
+ msg = gr.Textbox(placeholder="Type your driving question here...")
94
+ send = gr.Button("Send")
95
+
96
+ def user_interaction(user_message, chat_history):
97
+ bot_message = respond(user_message, chat_history)
98
+ chat_history.append((user_message, bot_message))
99
+ return "", chat_history
100
+
101
+ send.click(user_interaction, inputs=[msg, chatbot], outputs=[msg, chatbot])
102
+ msg.submit(user_interaction, inputs=[msg, chatbot], outputs=[msg, chatbot])
103
+
104
+ # Fix for localhost issue
105
+ demo.launch(share=True)