Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer, util | |
| import faiss | |
| import numpy as np | |
| from transformers import pipeline | |
| import time | |
| import ast | |
| import re | |
# --- 1. DATA LOADING AND INITIALIZATION ---
print("===== Application Startup =====")
start_time = time.time()

# Load the TravelPlanner "test" configuration in full; no row limit is applied.
# Search results from the FAISS index are used as row positions into the
# "test" split, so the index must have been built from these same rows in the
# same order.
# NOTE(review): confirm build_index.py indexes exactly this split/order.
print("Loading TravelPlanner dataset...")
dataset = load_dataset("osunlp/TravelPlanner", "test")
print("Dataset ready.")
# --- 2. EMBEDDING AND RECOMMENDATION ENGINE ---
print("Loading embedding model...")
model_name = "all-mpnet-base-v2"
embedding_model = SentenceTransformer(f"sentence-transformers/{model_name}")

index_file = "trip_index.faiss"
print(f"Loading FAISS index from {index_file}...")
try:
    index = faiss.read_index(index_file)
except RuntimeError as err:
    # faiss surfaces read failures (such as a missing file) as RuntimeError;
    # report the underlying reason rather than assuming the file is absent.
    print(f"Error: could not load FAISS index file '{index_file}': {err}")
    print("Please run the `build_index.py` script first to create the index.")
    # Exit with a non-zero status so the host sees a failed startup. The bare
    # `exit()` helper is injected by the `site` module and exits with 0.
    raise SystemExit(1) from err
print(f"Index is ready. Total vectors in index: {index.ntotal}")
# --- 3. SYNTHETIC GENERATION ---

# Section descriptions containing any of these words hold a header line plus
# one entry per line, and are rendered as a bold header with bullets.
_LIST_SECTION_KEYWORDS = ('Attractions', 'Restaurants', 'Accommodations', 'Flight')


def format_plan_details(plan_string):
    """
    Parse the raw plan string from the dataset and render it as Markdown.

    ``plan_string`` is expected to be the string form of a Python list of
    dicts carrying ``'Description'`` and ``'Content'`` keys. Each section
    becomes a ``####`` heading followed by its content:

    * attraction / restaurant / accommodation / flight sections: the first
      content line is bolded and every remaining non-blank line becomes a
      bullet with internal whitespace collapsed;
    * self-driving / taxi sections: one bullet prefixed with a car/taxi emoji;
    * anything else: the content as-is.

    Args:
        plan_string: Raw ``reference_information`` value from the dataset.

    Returns:
        The Markdown text, or ``plan_string`` unchanged when it is empty, not
        a string, does not look like a list literal, fails to parse, or does
        not parse to a list of dicts.
    """
    # Anything that is not a list literal is passed through untouched.
    if not isinstance(plan_string, str) or not plan_string.strip().startswith('['):
        return plan_string

    try:
        # literal_eval only accepts Python literals, so dataset text cannot
        # execute code here.
        plan_list = ast.literal_eval(plan_string)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Malformed input: show the raw text instead of failing the request.
        return plan_string

    # Only a list of dicts can be rendered section by section.
    if not isinstance(plan_list, list) or not all(isinstance(section, dict) for section in plan_list):
        return plan_string

    formatted_sections = []
    for section in plan_list:
        # `or` also covers keys that are present but None/empty.
        description = str(section.get('Description') or 'Details')
        content = str(section.get('Content') or '').strip()

        formatted_sections.append(f"#### {description}")

        if any(keyword in description for keyword in _LIST_SECTION_KEYWORDS):
            # Empty content would otherwise produce a stray "****" header.
            if content:
                header, *rows = content.split('\n')
                formatted_sections.append(f"**{header}**")
                for row in rows:
                    clean_row = ' '.join(row.split())  # collapse runs of whitespace
                    if clean_row:
                        formatted_sections.append(f"- {clean_row}")
        elif 'Self-driving' in description or 'Taxi' in description:
            mode_emoji = "🚗" if 'Self-driving' in description else "🚕"
            formatted_sections.append(f"- {mode_emoji} {content}")
        else:
            formatted_sections.append(content)

        # Blank line between sections.
        formatted_sections.append("")

    return "\n".join(formatted_sections)
# Lazily created text-generation pipeline, shared by every request.
_text_generator = None


def _get_text_generator():
    """Build the GPT-2 text-generation pipeline on first use and reuse it afterwards."""
    global _text_generator
    if _text_generator is None:
        print("Loading generative model...")
        _text_generator = pipeline('text-generation', model='gpt2')
    return _text_generator


def get_recommendations_and_generate(query_text, k=3, num_candidates=10):
    """
    Retrieve similar dataset trips and generate a new trip idea for a query.

    Args:
        query_text: Free-text trip description, e.g. "a 3-day trip to Paris".
        k: Number of nearest neighbours requested from the FAISS index.
        num_candidates: Number of GPT-2 completions sampled before keeping the
            one whose embedding is most similar to the query.

    Returns:
        A 4-tuple ``(rec1, rec2, rec3, generated_text)``. Each ``rec`` is a
        dict with ``dest``, ``days`` and ``reference_information`` keys; any
        slot without a match is filled with a placeholder whose ``days`` and
        ``reference_information`` are empty strings.
    """
    # 1. Retrieve recommendations from the existing data.
    query_vector = np.asarray(embedding_model.encode([query_text]), dtype=np.float32)
    _distances, indices = index.search(query_vector, k)

    split = dataset['test']
    results = []
    for idx_numpy in indices[0]:
        idx = int(idx_numpy)
        # FAISS pads with -1 when it has fewer than k neighbours; a negative
        # id would otherwise silently select a row from the end of the split.
        if idx < 0 or idx >= len(split):
            continue
        # Fetch one row rather than materialising whole columns per field.
        row = split[idx]
        results.append({
            "dest": row['dest'],
            "days": row['days'],
            "reference_information": row['reference_information'],
        })
    while len(results) < 3:
        results.append({"dest": "No trip plan found", "days": "", "reference_information": ""})

    # 2. Build the prompt for the generative model.
    prompt = f"Write a complete travel plan that includes a title and a day-by-day itinerary. The trip must be about: {query_text}."
    generator = _get_text_generator()

    # 3. Sample several candidate trips.
    print(f"Generating {num_candidates} synthetic trip ideas...")
    generated_outputs = generator(
        prompt,
        max_new_tokens=250,
        num_return_sequences=num_candidates,
        do_sample=True,  # several return sequences need sampling (or beams)
        return_full_text=False,  # return only the continuation, not the prompt
        pad_token_id=50256,  # GPT-2's <|endoftext|> id; GPT-2 has no pad token
    )

    # 4. Keep the candidate most similar to the user's query.
    print("Finding the most relevant generated trip...")
    generated_texts = [output['generated_text'].strip() for output in generated_outputs]
    generated_embeddings = embedding_model.encode(generated_texts)
    similarities = util.cos_sim(query_vector, generated_embeddings)
    best_index = int(similarities.argmax())
    best_generated_trip = generated_texts[best_index]

    return results[0], results[1], results[2], best_generated_trip
# --- 4. GRADIO USER INTERFACE ---
def format_trip_plan(trip):
    """
    Render one recommended trip as Markdown for the results column.

    Args:
        trip: Dict with ``dest``, ``days`` and ``reference_information`` keys.

    Returns:
        A heading plus the formatted plan, or a "no similar trip" message when
        ``trip`` is missing or has no plan text. The padding entries used to
        fill empty result slots have an empty ``reference_information``, so
        they take the "no similar trip" path instead of producing a broken
        heading.
    """
    if not trip or not trip.get('reference_information'):
        return "### No similar trip plan found."
    formatted_plan = format_plan_details(trip['reference_information'])
    destination = str(trip.get('dest') or '').upper()
    return f"### {trip.get('days', '')}-day trip to {destination}\n**Suggested Plan:**\n{formatted_plan}"
def format_generated_trip(trip_text):
    """Hand the AI-generated trip text to the output textbox exactly as received."""
    display_text = trip_text
    return display_text
def trip_planner_wizard(destination, days):
    """
    Gradio click handler: build a query from the form and render all outputs.

    Args:
        destination: Text from the destination textbox.
        days: Value of the number-of-days input; may be None when the field
            is cleared, or a non-integer number.

    Returns:
        Four values for the UI: three Markdown recommendations followed by
        the generated trip text.

    Raises:
        gr.Error: If the destination is blank or the day count is missing or
            below one; Gradio displays the message to the user.
    """
    destination = (destination or "").strip()
    if not destination:
        raise gr.Error("Please enter a destination.")
    try:
        days = int(days)  # also truncates fractional input, e.g. 2.5 -> 2
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Please enter the number of days.") from None
    if days < 1:
        raise gr.Error("The trip must last at least one day.")

    # Combine user inputs into a single query for processing.
    query_text = f"a {days}-day trip to {destination}"
    rec1, rec2, rec3, gen_rec_text = get_recommendations_and_generate(query_text)
    return format_trip_plan(rec1), format_trip_plan(rec2), format_trip_plan(rec3), format_generated_trip(gen_rec_text)
# Report the wall-clock time spent on the module-level startup work above.
end_time = time.time()
load_seconds = end_time - start_time
print(f"Models and data loaded in {load_seconds:.2f} seconds.")
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✈️ TripPlanner AI")
    gr.Markdown("Enter your destination and desired trip length, and get plan recommendations plus a new AI-generated idea!")

    with gr.Row():
        destination_input = gr.Textbox(label="Destination", placeholder="e.g., Paris")
        # Whole, positive day counts only: precision=0 rounds the value to an
        # int, and minimum=1 makes Gradio reject values below one.
        days_input = gr.Number(label="Number of Days", value=3, precision=0, minimum=1)

    with gr.Row():
        submit_btn = gr.Button("Get Trip Plans", variant="primary")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Recommended Trip Plans from Dataset")
            output_rec1 = gr.Markdown()
            output_rec2 = gr.Markdown()
            output_rec3 = gr.Markdown()
        with gr.Column(scale=1):
            gr.Markdown("### ✨ New AI-Generated Idea")
            output_gen = gr.Textbox(label="AI Generated Trip Plan", lines=20, interactive=False)

    submit_btn.click(
        fn=trip_planner_wizard,
        inputs=[destination_input, days_input],
        outputs=[output_rec1, output_rec2, output_rec3, output_gen],
    )

    gr.Examples(
        examples=[
            ["Paris", 3],
            ["Orlando", 7],
            ["Tokyo", 5],
            ["the Greek Islands", 10],
        ],
        inputs=[destination_input, days_input],
    )

demo.launch(ssr_mode=False)