# pragyarama's picture
# Update app.py
# 08284e6 verified
import gradio as gr
from huggingface_hub import InferenceClient
from gradio.themes.utils import colors, fonts, sizes
from sentence_transformers import SentenceTransformer
import torch
import random
# STEP 1 FROM SEMANTIC SEARCH
from sentence_transformers import SentenceTransformer
import torch
# STEP 2 FROM SEMANTIC SEARCH
# Load the raw travel-tips corpus from disk (expected alongside app.py).
with open("formatted_travel_tips.txt", "r", encoding="utf-8") as tips_file:
    travel_text = tips_file.read()
# STEP 3 FROM SEMANTIC SEARCH
def preprocess_text(text):
    """Clean raw corpus text and split it into non-empty chunks.

    Args:
        text: Raw corpus string whose sections are separated by '+' characters.

    Returns:
        list[str]: Stripped, non-empty chunk strings in their original order.
    """
    # Strip-and-filter in a single comprehension; the per-chunk debug print
    # ("make sure the correct file is being read") was development residue
    # and has been removed.
    return [chunk.strip() for chunk in text.strip().split("+") if chunk.strip()]
# Module-level chunk list reused by every semantic-search query below.
cleaned_chunks = preprocess_text(travel_text)
# STEP 4 FROM SEMANTIC SEARCH
# Sentence-embedding model shared by corpus indexing and query encoding.
model = SentenceTransformer('all-MiniLM-L6-v2')


def create_embeddings(text_chunks):
    """Encode each chunk into a dense vector; returns a 2-D tensor (one row per chunk)."""
    return model.encode(text_chunks, convert_to_tensor=True)


# Pre-compute the corpus embeddings once at startup.
chunk_embeddings = create_embeddings(cleaned_chunks)
# STEP 5 FROM SEMANTIC SEARCH
def get_top_chunks(query, chunk_embeddings, text_chunks, top_k=1):
    """Return the `top_k` text chunks most semantically similar to `query`.

    Args:
        query: Natural-language search string.
        chunk_embeddings: 2-D tensor of corpus embeddings, row-aligned with
            `text_chunks`.
        text_chunks: The chunk strings the embeddings were built from.
        top_k: Number of best matches to return. Defaults to 1 to preserve
            the original behaviour (the old comment claimed "top 3" while the
            code hard-coded k=1); clamped to the corpus size so `topk` never
            raises on a small corpus.

    Returns:
        list[str]: The best-matching chunks, most similar first.
    """
    query_embedding = model.encode(query, convert_to_tensor=True)
    # Cosine similarity = dot product of L2-normalised vectors.
    query_normalized = query_embedding / query_embedding.norm()
    chunks_normalized = chunk_embeddings / chunk_embeddings.norm(dim=1, keepdim=True)
    similarities = torch.matmul(chunks_normalized, query_normalized)
    # Debug prints of the similarity tensor / indices removed.
    k = min(top_k, len(text_chunks))
    top_indices = torch.topk(similarities, k=k).indices
    return [text_chunks[i] for i in top_indices]
# STEP 6 FROM SEMANTIC SEARCH
# Smoke-test the retrieval pipeline with a sample question at startup.
sample_question = "Why is it important to carry copies of your travel documents?"
top_results = get_top_chunks(sample_question, chunk_embeddings, cleaned_chunks)
print(top_results)
# HUGGING FACE PROJECT
# Hosted chat-completion model used to generate Miles's replies.
client = InferenceClient("Qwen/Qwen2.5-72B-instruct")


def respond(message, history, name, country, season):
    """Generate a travel-agent reply grounded in the retrieved corpus chunks.

    Args:
        message: The user's current question.
        history: Prior turns as a list of [user, assistant] pairs.
        name: User's name from the survey.
        country: Destination country from the survey.
        season: Travel season from the survey.

    Returns:
        tuple: ([message, final_reply], updated history list) — the two
        values the Gradio click handler's outputs expect.
    """
    # Bias retrieval toward the destination by appending the country name.
    new_message = message + " " + country
    top_chunks = get_top_chunks(new_message, chunk_embeddings, cleaned_chunks)
    str_top_chunks = "\n".join(top_chunks)
    messages = [
        {
            "role": "system",
            # Fix: interpolate the joined chunk text (str_top_chunks, which
            # was computed but never used) instead of the raw list repr.
            "content": f"You are a friendly travel agent. Base your response on the following data: {str_top_chunks}. The user's name is {name} and they are traveling to {country} in the {season}. Don't create anything, only take from our data. Look for {country} in the data and pull the specific information from the data. Help the customer to plan outfits for their vacation and etiquette for the country. Additionally, if they have questions about food, reference the data to tell them about all the traditional and popular foods. Prompt the user about any diet they are on or if they have allergens if they ask about food."
        }
    ]
    # Fix: replay prior turns in chronological user/assistant order. The old
    # code emitted every user turn, then every assistant turn, and sent the
    # current question twice (once as "Question: ..." and once appended raw).
    if history:
        for user_turn, assistant_turn in history:
            messages.append({"role": "user", "content": user_turn})
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": f"Question: {message}"})
    response = client.chat_completion(
        messages,
        max_tokens=500,
        temperature=0.2,
    )
    bot_reply = response["choices"][0]["message"]["content"].strip()
    catchphrases = ["Trade fights for flights", "Red flag? Pack that bag", "No more relation, find that vacation", "Ditch the drama, find your nirvana", "Ditch the fights, catch the flights", "Relationship hiccup? Flight to Europe", "The only baggage you need is carry-on", "Skip the emotional drain to get on a plane"]
    final_reply = random.choice(catchphrases) + "!\n\n" + bot_reply
    return [message, final_reply], history + [[message, final_reply]]
# SURVEY
def submit_survey(name, country, season):
    """Acknowledge the survey and pass the answers through to session state.

    Returns a confirmation string followed by the three answers, which the
    click handler routes into the name/country/season State components.
    """
    print(f"Survey received - Name: {name}, Country: {country}, Season: {season}")
    confirmation = (
        f"Thank you, {name}! Your survey response has been recorded. "
        "Ask Miles about any questions you have!"
    )
    return confirmation, name, country, season
class MyCustomTheme(gr.themes.Base):
    """Gradio theme: green primary hue, uniform off-white neutrals, Poppins font."""

    # Single off-white applied to every neutral shade slot.
    _OFFWHITE = "#fcf1db"

    def __init__(self):
        shade_names = ["c50"] + [f"c{n}" for n in range(100, 1000, 100)] + ["c950"]
        super().__init__(
            primary_hue=colors.green,
            neutral_hue=colors.Color(
                name="custom_offwhite",
                **{shade: self._OFFWHITE for shade in shade_names},
            ),
            font=fonts.GoogleFont("Poppins"),
        )
# ---- UI wiring ---------------------------------------------------------
custom_theme = MyCustomTheme()

with gr.Blocks(theme=custom_theme, css="* { color: #0c265b !important; }") as demo:
    # Session state: survey answers plus the running chat transcript.
    name_state = gr.State()
    country_state = gr.State()
    season_state = gr.State()
    chat_state = gr.State([])

    gr.Markdown("## Survey")
    with gr.Column():
        name = gr.Textbox(label="Name", placeholder="Enter your name")
        country = gr.Dropdown(label="Country", choices=["Argentina", "Australia", "Brazil", "Canada", "China", "Costa Rica", "Ecuador", "Egypt", "England", "France", "Germany", "Greece", "India", "Ireland", "Italy", "Jamaica", "Japan", "Kenya", "Malaysia", "Maldives", "Mexico", "Morocco", "Netherlands", "New Zealand", "Norway", "Panama", "Peru", "Philippines", "Poland", "Portugal", "Russia", "South Africa", "South Korea", "Spain", "Sweden", "Switzerland", "Taiwan", "Thailand", "Turkey", "USA", "Vietnam", "Zimbabwe"])
        season = gr.Radio(label="Season", choices=["Spring", "Summer", "Autumn", "Winter"])
        submit_btn = gr.Button("Submit Survey")
        survey_output = gr.Textbox(label="Response", interactive=False)

    # Store the survey answers in session state and show a confirmation.
    submit_btn.click(
        fn=submit_survey,
        inputs=[name, country, season],
        outputs=[survey_output, name_state, country_state, season_state],
    )

    gr.Markdown("---")
    gr.Markdown("## Miles")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Hi, I'm Miles! Ask me anything about your trip: food, outfit, etiquette, or travel tips!")
    send_btn = gr.Button("Send")

    # NOTE(review): `respond` returns ([message, reply], full_history) while
    # both outputs are wired to the same chat_state component — confirm the
    # intended first output before changing this wiring.
    send_btn.click(
        fn=respond,
        inputs=[msg, chat_state, name_state, country_state, season_state],
        outputs=[chat_state, chat_state],
    ).then(
        fn=lambda history: history,
        inputs=[chat_state],
        outputs=chatbot,
    )

demo.launch()