Spaces:
Configuration error
Configuration error
| """ | |
| Interactive setup and chat interface for DeepDrone. | |
| """ | |
| import os | |
| import sys | |
| import asyncio | |
| from typing import Dict, Optional, Tuple, List | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.text import Text | |
| from rich.align import Align | |
| from rich.prompt import Prompt, Confirm | |
| from rich.table import Table | |
| from rich.live import Live | |
| from rich.layout import Layout | |
| from rich.spinner import Spinner | |
| from prompt_toolkit import prompt | |
| from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog, message_dialog | |
| from prompt_toolkit.styles import Style | |
| import getpass | |
| from .config import ModelConfig | |
| from .drone_chat_interface import DroneChatInterface | |
# Shared Rich console that the helpers in this module print and prompt through.
console: Console = Console()
# Provider configurations, keyed by the display name shown in the provider
# selection table. For each provider:
#   name         - backend identifier, passed on as ModelConfig.provider
#   models       - suggested model ids offered in the model picker
#   api_key_url  - where the user can obtain credentials (or install Ollama)
#   description  - one-line summary shown in the selection table
PROVIDERS = dict(
    OpenAI=dict(
        name="openai",
        models=["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"],
        api_key_url="https://platform.openai.com/api-keys",
        description="GPT models from OpenAI",
    ),
    Anthropic=dict(
        name="anthropic",
        models=[
            "claude-3-5-sonnet-20241022",
            "claude-3-sonnet-20240229",
            "claude-3-haiku-20240307",
        ],
        api_key_url="https://console.anthropic.com/",
        description="Claude models from Anthropic",
    ),
    Google=dict(
        name="vertex_ai",
        models=["gemini-1.5-pro", "gemini-1.5-flash", "gemini-pro"],
        api_key_url="https://console.cloud.google.com/",
        description="Gemini models from Google",
    ),
    Meta=dict(
        # Llama models are reached through OpenAI-compatible hosting providers.
        name="openai",
        models=[
            "meta-llama/Meta-Llama-3.1-70B-Instruct",
            "meta-llama/Meta-Llama-3.1-8B-Instruct",
        ],
        api_key_url="https://together.ai/ or https://replicate.com/",
        description="Llama models from Meta (via Together.ai/Replicate)",
    ),
    Mistral=dict(
        name="mistral",
        models=["mistral-large-latest", "mistral-medium-latest", "mistral-small-latest"],
        api_key_url="https://console.mistral.ai/",
        description="Mistral AI models",
    ),
    Ollama=dict(
        name="ollama",
        models=["llama3.1:latest", "codestral:latest", "qwen2.5-coder:latest", "phi3:latest"],
        api_key_url="https://ollama.ai/ (No API key needed - runs locally)",
        description="Local models via Ollama (no API key required)",
    ),
)
def show_welcome_banner():
    """Print the DeepDrone welcome banner, centred inside a green Rich panel."""
    frame = [
        "ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ",
        "β β",
        "β π DEEPDRONE AI CONTROL SYSTEM π β",
        "β β",
        "β Advanced Drone Control with AI Integration β",
        "β β",
        "ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ",
    ]
    centred = Align.center(Text("\n".join(frame), style="bold green"))
    console.print(Panel(centred, border_style="bright_green", padding=(1, 2)))
def select_provider() -> Optional[Tuple[str, Dict]]:
    """Show the provider table and let the user pick a provider by number.

    Returns:
        ``(display_name, config)`` for the chosen entry of ``PROVIDERS``, or
        ``None`` if the user pressed Ctrl-C at the prompt.
    """
    console.print("\n[bold cyan]π‘ Select AI Provider[/bold cyan]\n")

    entries = list(PROVIDERS.items())

    table = Table(show_header=True, header_style="bold magenta")
    for heading, column_style, column_width in (
        ("β", "bright_green", 3),
        ("Provider", "cyan", 12),
        ("Description", "white", None),
        ("Example Models", "yellow", None),
    ):
        table.add_column(heading, style=column_style, width=column_width)

    for position, (label, cfg) in enumerate(entries, start=1):
        models = cfg["models"]
        # Preview the first two models; an ellipsis marks that more exist.
        preview = ", ".join(models[:2]) + ("..." if len(models) > 2 else "")
        table.add_row(str(position), label, cfg["description"], preview)

    console.print(table)
    console.print()

    try:
        from rich.prompt import IntPrompt

        picked = IntPrompt.ask(
            "Select provider by number",
            choices=[str(n) for n in range(1, len(entries) + 1)],
            default=1,
        )
        return entries[picked - 1]
    except KeyboardInterrupt:
        console.print("\n[yellow]Selection cancelled[/yellow]")
        return None
def get_available_ollama_models() -> List[str]:
    """Return the names of the models already available in the local Ollama.

    Two response shapes of ``ollama.list()`` are accepted:

    * an object exposing a ``models`` attribute whose items carry a ``model``
      attribute, and
    * a plain mapping ``{"models": [{"model": ..., "name": ...}, ...]}`` --
      presumably what older ollama-python clients return; TODO confirm
      against the pinned client version.

    Returns:
        Model names in the order reported by Ollama, or an empty list when the
        ``ollama`` package is missing, the server cannot be reached, or the
        response has an unexpected shape. Every ``Exception`` is swallowed;
        callers treat an empty list as "no local models / Ollama not running".
    """
    from collections.abc import Mapping

    try:
        import ollama
    except ImportError:
        return []

    try:
        response = ollama.list()
        if isinstance(response, Mapping):
            entries = response.get("models") or []
        else:
            entries = getattr(response, "models", None) or []

        names: List[str] = []
        for entry in entries:
            if isinstance(entry, Mapping):
                name = entry.get("model") or entry.get("name")
            else:
                name = getattr(entry, "model", None)
            # Skip entries without a usable name instead of returning None.
            if name:
                names.append(name)
        return names
    except Exception:
        # Best effort: connection errors and unexpected payloads are both
        # reported to callers as "no local models".
        return []
def install_ollama_model(model_name: str) -> bool:
    """Pull ``model_name`` into the local Ollama server behind a spinner.

    Returns:
        True once the pull has finished, False when the ``ollama`` package is
        missing or the pull raised (the reason is printed to the console).
    """
    try:
        import ollama

        console.print(f"[yellow]π₯ Installing {model_name}... This may take a few minutes.[/yellow]")
        spinner = Spinner("dots", text=f"Installing {model_name}...")
        # transient=True clears the spinner line once the block exits.
        with Live(spinner, console=console, transient=True):
            ollama.pull(model_name)
        console.print(f"[green]β Successfully installed {model_name}[/green]")
    except ImportError:
        console.print("[red]β Ollama package not installed[/red]")
        return False
    except Exception as exc:
        console.print(f"[red]β Failed to install {model_name}: {exc}[/red]")
        return False
    return True
def get_model_name(provider_name: str, provider_config: Dict) -> Optional[str]:
    """Ask the user which model to use with ``provider_name``.

    The suggested models come from ``provider_config["models"]``. The user may
    answer with a number from the printed list or type any model id. A number
    outside the list is rejected and the question is asked again. For Ollama,
    models already pulled locally are listed first, and a model that is not
    installed yet is offered for download before being returned.

    Args:
        provider_name: Display name of the provider (a key of ``PROVIDERS``).
        provider_config: The provider's entry; ``"models"`` and
            ``"api_key_url"`` are read.

    Returns:
        The selected model id, or None if the user cancels (Ctrl-C / EOF),
        gives a blank answer, declines an Ollama download, or the download
        fails.
    """
    console.print(f"\n[bold cyan]π€ Select Model for {provider_name}[/bold cyan]\n")
    if provider_name.lower() == "ollama":
        return _select_ollama_model(provider_config)
    return _select_hosted_model(provider_config)


def _select_hosted_model(provider_config: Dict) -> Optional[str]:
    """List the suggested models of an API-key provider and read the choice."""
    models = provider_config["models"]
    console.print("[bold]Popular models for this provider:[/bold]")
    for i, model in enumerate(models, 1):
        console.print(f" {i}. [green]{model}[/green]")
    console.print(f"\n[dim]Get API key from: {provider_config['api_key_url']}[/dim]\n")
    return _ask_for_model(models)


def _select_ollama_model(provider_config: Dict) -> Optional[str]:
    """List local and suggested Ollama models, read the choice, pull if needed."""
    local_models = get_available_ollama_models()
    installed = {_ollama_canonical_name(m) for m in local_models}
    # Drop suggestions that are already installed; otherwise they would be
    # listed twice, the second time wrongly labelled "(will be downloaded)".
    suggestions = [
        m for m in provider_config["models"] if _ollama_canonical_name(m) not in installed
    ]

    if local_models:
        console.print("[bold green]β Local Ollama models found:[/bold green]")
        for i, model in enumerate(local_models, 1):
            console.print(f" {i}. [green]{model}[/green]")
        if suggestions:
            console.print("\n[bold]Popular models (if not installed locally):[/bold]")
    else:
        console.print("[yellow]β οΈ No local Ollama models found or Ollama not running[/yellow]")
        console.print("Make sure Ollama is running: [cyan]ollama serve[/cyan]\n")
        console.print("[bold]Popular models (will be downloaded):[/bold]")
    # Numbering continues after the local models so it matches the option
    # list handed to _ask_for_model below.
    for i, model in enumerate(suggestions, len(local_models) + 1):
        console.print(f" {i}. [blue]{model}[/blue] [dim](will be downloaded)[/dim]")
    console.print(f"\n[dim]Download from: {provider_config['api_key_url']}[/dim]\n")

    choice = _ask_for_model(local_models + suggestions)
    if choice is None or _ollama_canonical_name(choice) in installed:
        return choice
    return _offer_ollama_install(choice)


def _ask_for_model(options: List[str]) -> Optional[str]:
    """Prompt until the answer is a valid list number or a model name.

    A numeric answer selects ``options[answer - 1]``. A number outside
    ``1..len(options)`` is rejected and the prompt repeats, instead of being
    taken as a literal model id. Any other non-blank answer is returned
    stripped. Returns None on a blank answer or on Ctrl-C / EOF.
    """
    while True:
        try:
            answer = Prompt.ask("Enter model name or number from list above", default="1")
        except (KeyboardInterrupt, EOFError):
            console.print("\n[yellow]Input cancelled[/yellow]")
            return None
        answer = (answer or "").strip()
        if not answer:
            return None
        try:
            index = int(answer)
        except ValueError:
            return answer
        if 1 <= index <= len(options):
            return options[index - 1]
        console.print(
            f"[red]Invalid choice {index}: enter a number from 1 to {len(options)} "
            f"or a model name.[/red]"
        )


def _offer_ollama_install(model_name: str) -> Optional[str]:
    """Offer to pull a model that is missing locally.

    Returns ``model_name`` only if the user agrees and the pull succeeds;
    otherwise None (declined, failed, or Ctrl-C / EOF).
    """
    console.print(f"[yellow]Model '{model_name}' not found locally.[/yellow]")
    try:
        if not Confirm.ask(f"Would you like to install {model_name}?", default=True):
            console.print("[yellow]Model installation cancelled[/yellow]")
            return None
        return model_name if install_ollama_model(model_name) else None
    except (KeyboardInterrupt, EOFError):
        console.print("\n[yellow]Input cancelled[/yellow]")
        return None


def _ollama_canonical_name(name: str) -> str:
    """Return ``name`` with an implicit ``:latest`` tag made explicit.

    This lets an untagged name such as ``llama3.1`` match the installed
    ``llama3.1:latest``. Only the last ``/`` segment is inspected, so a
    registry host with a port (``host:5000/model``) is not mistaken for a tag.
    """
    last_segment = name.rsplit("/", 1)[-1]
    return name if ":" in last_segment else f"{name}:latest"
def get_api_key(provider_name: str, model_name: str) -> Optional[str]:
    """Ask for the API key used to call ``model_name`` through ``provider_name``.

    Ollama needs no key, so the placeholder ``"local"`` is returned for it
    without prompting. For every other provider the key is read with
    ``getpass.getpass`` so it is not echoed.

    Returns:
        The stripped key, ``"local"`` for Ollama, or None when the entry is
        blank or the user presses Ctrl-C.
    """
    for line in (
        f"\n[bold cyan]π API Key for {provider_name}[/bold cyan]\n",
        f"Model: [green]{model_name}[/green]",
        f"Provider: [blue]{provider_name}[/blue]\n",
    ):
        console.print(line)

    if provider_name.lower() == "ollama":
        console.print("[green]β Ollama runs locally - no API key required![/green]")
        console.print("[dim]Make sure Ollama is running: ollama serve[/dim]\n")
        return "local"

    try:
        entered = getpass.getpass("Enter your API key (hidden): ").strip()
        if not entered:
            console.print("[yellow]No API key provided[/yellow]")
            return None
        return entered
    except KeyboardInterrupt:
        console.print("\n[yellow]Input cancelled[/yellow]")
        return None
def test_model_connection(model_config: ModelConfig) -> bool:
    """Send a test request through ``LLMInterface`` and report the outcome.

    ``LLMInterface.test_connection()`` is read as a dict with a ``"success"``
    flag plus ``"response"`` (on success) or ``"error"`` (on failure). Those
    are the keys used here; TODO confirm the full contract in llm_interface.

    Returns:
        True if the test request succeeded, False otherwise. This includes the
        case where building the interface or making the call raised: the
        exception is printed, not propagated.
    """
    from rich.markup import escape

    console.print(f"\n[yellow]π Testing connection to {model_config.name}...[/yellow]")
    try:
        from .llm_interface import LLMInterface

        with Live(
            Spinner("dots", text="Testing API connection..."),
            console=console,
            transient=True,
        ):
            result = LLMInterface(model_config).test_connection()

        if result.get("success"):
            reply = str(result.get("response") or "")
            # Only mark truncation when something was actually cut off.
            preview = reply[:100] + ("..." if len(reply) > 100 else "")
            console.print("[green]β Connection successful![/green]")
            # escape(): model output may contain "[...]" that Rich would parse
            # as markup (an unmatched "[/x]" raises MarkupError).
            console.print(f"[dim]Response: {escape(preview)}[/dim]\n")
            return True

        error = result.get("error", "unknown error")
        console.print(f"[red]β Connection failed: {escape(str(error))}[/red]\n")
        return False
    except Exception as e:
        console.print(f"[red]β Error testing connection: {escape(str(e))}[/red]\n")
        return False
def start_interactive_session():
    """Run the interactive setup wizard, then start the DeepDrone chat.

    Steps:
        1. choose a provider
        2. choose a model
        3. enter an API key
        4. test the LLM connection (the user may continue on failure)
        5. enter the drone connection string
        6. start ``DroneChatInterface``

    Cancelling any step prints a goodbye message and returns. Ctrl-C or end of
    input (Ctrl-D) exits with status 0. Any other exception is printed and
    exits with status 1.
    """
    cancelled = "[yellow]Setup cancelled. Goodbye![/yellow]"
    try:
        show_welcome_banner()

        # Step 1: provider
        console.print("[bold]Step 1: Choose your AI provider[/bold]\n")
        provider_result = select_provider()
        if not provider_result:
            console.print(cancelled)
            return
        provider_name, provider_config = provider_result

        # Step 2: model
        console.print(f"[bold]Step 2: Select model for {provider_name}[/bold]")
        model_name = get_model_name(provider_name, provider_config)
        if not model_name:
            console.print(cancelled)
            return

        # Step 3: API key (get_api_key returns a placeholder for Ollama)
        console.print("[bold]Step 3: Enter API key[/bold]")
        api_key = get_api_key(provider_name, model_name)
        if not api_key:
            console.print(cancelled)
            return

        is_ollama = provider_name.lower() == "ollama"
        model_config = ModelConfig(
            name=f"{provider_name.lower()}-session",
            provider=provider_config["name"],
            model_id=model_name,
            api_key=api_key,
            # Ollama is reached on localhost:11434; other providers use their
            # default endpoint.
            base_url="http://localhost:11434" if is_ollama else None,
            max_tokens=2048,
            temperature=0.7,
        )

        # Step 4: connection test
        console.print("[bold]Step 4: Testing connection[/bold]")
        if not test_model_connection(model_config):
            if not Confirm.ask("Connection test failed. Continue anyway?"):
                console.print(cancelled)
                return

        # Step 5: drone connection string
        connection_string = _prompt_drone_connection()

        # Step 6: start chat
        console.print("[bold green]π Starting DeepDrone chat session...[/bold green]\n")
        import time

        time.sleep(1)  # Small delay before handing over to the chat UI.
        DroneChatInterface(model_config, connection_string).start()
    except (KeyboardInterrupt, EOFError):
        console.print("\n[yellow]π DeepDrone session interrupted. Goodbye![/yellow]")
        sys.exit(0)
    except Exception as e:
        console.print(f"[red]β Error in interactive session: {e}[/red]")
        sys.exit(1)


def _detect_simulator() -> Optional[bool]:
    """Look for a MAVProxy/SITL process in the ``ps aux`` listing.

    Returns True if one is listed and False if not. Returns None when the
    process list could not be read: ``ps`` is missing (e.g. on Windows), it
    failed to run, or it did not answer within 5 seconds.
    """
    import subprocess

    try:
        completed = subprocess.run(
            ["ps", "aux"],
            capture_output=True,
            text=True,
            errors="replace",  # never fail on undecodable bytes in ps output
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        return None
    listing = (completed.stdout or "").lower()
    return "mavproxy" in listing or "sitl" in listing


def _prompt_drone_connection() -> str:
    """Show drone connection examples and ask for the connection string.

    Returns the string the user entered, or the simulator default
    ``udp:127.0.0.1:14550`` when the answer is blank. Simulator detection only
    changes the status message; the offered default is the same either way.
    """
    default_connection = "udp:127.0.0.1:14550"
    console.print("[bold yellow]π Drone Connection Setup[/bold yellow]\n")

    simulator_running = _detect_simulator()
    if simulator_running:
        console.print("[green]β Detected running drone simulator![/green]")
    elif simulator_running is False:
        console.print("[yellow]β οΈ No simulator detected[/yellow]")

    console.print("Connection options:")
    console.print(" β’ [green]Simulator[/green]: [cyan]udp:127.0.0.1:14550[/cyan] (default)")
    console.print(" β’ [blue]Real Drone USB[/blue]: [cyan]/dev/ttyACM0[/cyan] (Linux) or [cyan]COM3[/cyan] (Windows)")
    console.print(" β’ [blue]Real Drone TCP[/blue]: [cyan]tcp:192.168.1.100:5760[/cyan]")
    console.print(" β’ [blue]Real Drone UDP[/blue]: [cyan]udp:192.168.1.100:14550[/cyan]\n")

    connection_string = Prompt.ask("Enter drone connection string", default=default_connection)
    if not connection_string:
        console.print(f"[yellow]Using default connection: {default_connection}[/yellow]")
        connection_string = default_connection

    console.print(f"[dim]Will connect to: {connection_string}[/dim]\n")
    return connection_string