# drive_wise / app.py
# Source: Hugging Face Space "drive_wise" by ewingreen (commit ad83769, verified)
import gradio as gr
from huggingface_hub import InferenceClient
import difflib # for fuzzy matching
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
# ==== Load knowledge ====
# Read the reference material once at startup and split it into
# paragraph-sized chunks (blank-line separated) for retrieval.
with open("knowledge.txt", "r", encoding="utf-8") as f:
    knowledge_text = f.read()

chunks = [chunk.strip() for chunk in knowledge_text.split("\n\n") if chunk.strip()]

# Embed every chunk up front so each query needs only a single encode call.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
chunk_embeddings = embedder.encode(chunks, convert_to_tensor=True)
def get_relevant_context(query, top_k=3):
    """Return the knowledge chunks most similar to *query*.

    Embeds the query with the module-level sentence transformer, scores
    it against the precomputed chunk embeddings by cosine similarity,
    and joins the best matches with blank lines for use as prompt
    context.

    Args:
        query: The user's question (plain string).
        top_k: Maximum number of chunks to return.

    Returns:
        The selected chunks concatenated with "\n\n".
    """
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # Normalize both sides so the dot product below is cosine similarity.
    query_embedding = query_embedding / query_embedding.norm()
    norm_chunk_embeddings = chunk_embeddings / chunk_embeddings.norm(dim=1, keepdim=True)
    similarities = torch.matmul(norm_chunk_embeddings, query_embedding)
    # Clamp k so torch.topk cannot raise when the knowledge base has
    # fewer than top_k chunks.
    k = min(top_k, len(chunks))
    top_k_indices = torch.topk(similarities, k=k).indices.cpu().numpy()
    return "\n\n".join(chunks[i] for i in top_k_indices)
# ==== Model ====
# Hosted inference client for the instruction-tuned Gemma 2 2B chat model;
# all completions are served remotely via the Hugging Face Inference API.
client = InferenceClient("google/gemma-2-2b-it")
def is_driving_related(message):
    """Heuristically decide whether *message* is about driving.

    Single-word keywords are fuzzy-matched against each word of the
    message (SequenceMatcher ratio > 0.8) to tolerate typos and trailing
    punctuation. Multi-word keywords are checked as substrings of the
    whole lowered message.

    BUG FIX: phrases such as "stop sign" or "pulled over" were previously
    compared against individual words, where they can never reach the 0.8
    similarity threshold, making them dead entries in the keyword list.

    Returns:
        True if the message appears to be a driving question, else False.
    """
    driving_keywords = [
        "drive", "driving", "permit", "car", "road", "lane", "traffic",
        "license", "parallel", "park", "stop sign", "brake", "accelerate",
        "merge", "intersection", "seatbelt", "speed limit", "turn signal",
        "pulled over", "parking", "roundabout",
    ]
    lowered = message.lower()
    message_words = lowered.split()
    for keyword in driving_keywords:
        if " " in keyword:
            # Phrase keyword: plain substring match on the whole message.
            if keyword in lowered:
                return True
        else:
            for word in message_words:
                if difflib.SequenceMatcher(None, word, keyword).ratio() > 0.8:
                    return True
    return False
def respond(message, history):
    """Stream the bot's reply to *message*, yielding growing partials.

    On the first turn, off-topic questions get a fixed redirect reply.
    Otherwise the system prompt plus the full chat history is sent to
    the hosted model with streaming enabled, and the accumulated
    response is yielded after each token.

    Args:
        message: The latest user message.
        history: List of (user_msg, assistant_msg) pairs so far.

    Yields:
        Progressively longer versions of the assistant's reply.
    """
    # First message of a conversation: only answer driving questions.
    if not history:
        if not is_driving_related(message):
            # BUG FIX: this is a generator, so `return "<text>"` silently
            # discarded the redirect message; it must be yielded.
            yield "Hey there! I’m Drive Wise 🚗 — I can only help with driving topics like road rules, parking tips, or permit prep. What driving question do you have for me?"
            return
    messages = [{"role": "system", "content": """You are Drive Wise, a friendly and supportive AI chatbot designed to help new drivers learn essential driving skills and traffic laws. Your goal is to make learning to drive simple, confident, and stress-free."""}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        # Stream tokens from the model, yielding the growing reply.
        # (Loop variable renamed: it previously shadowed `message`.)
        for chunk in client.chat_completion(
            messages,
            max_tokens=500,
            temperature=0.1,
            stream=True,
        ):
            token = chunk.choices[0].delta.content
            # delta.content can be None on some stream events; skip those
            # instead of raising TypeError on `response += None`.
            if token:
                response += token
            yield response
    except Exception as e:
        print(f"An error occurred: {e}")
        # Surface the failure to the caller instead of ending silently.
        yield response or "Sorry, something went wrong while generating a reply. Please try again."
# === UI ===
about_text = """
## About this bot
This chatbot will help you learn more about driving!
Ask me about road rules, parking, your learner’s permit, or how to handle situations like being pulled over.
"""

with gr.Blocks() as demo:
    gr.Markdown(about_text)
    chatbot = gr.Chatbot(label="PERMIT TEST QUESTIONS")
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your driving question here...")
        send = gr.Button("Send")

    def user_interaction(user_message, chat_history):
        """Handle one chat turn: clear the textbox, extend the history.

        BUG FIX: `respond` is a generator; the original code appended the
        generator object itself to the chat history instead of its text.
        Drain the stream and keep the final (longest) partial reply.
        """
        bot_message = ""
        for partial in respond(user_message, chat_history):
            bot_message = partial
        if not bot_message:
            bot_message = "Sorry, I couldn't generate a response. Please try again."
        chat_history.append((user_message, bot_message))
        # First output clears the textbox; second updates the chatbot.
        return "", chat_history

    send.click(user_interaction, inputs=[msg, chatbot], outputs=[msg, chatbot])
    msg.submit(user_interaction, inputs=[msg, chatbot], outputs=[msg, chatbot])

# Fix for localhost issue
demo.launch(share=True)