Spaces:
Running
Running
| """ | |
| Atlas - Minimal VAD version based on Gradio's official pattern | |
| """ | |
| import gradio as gr | |
| import asyncio | |
| import logging | |
| import tempfile | |
| import numpy as np | |
| import wave | |
| import io | |
| import time | |
| import re | |
| import ast | |
| import json | |
| import os | |
| import sys | |
| import atexit | |
| import subprocess | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Tuple | |
| from services.mcp_client import MCPClient | |
| from services.audio_service import AudioService | |
| from services.llm_service import LLMService | |
| from services.screen_service import get_screen_service | |
| from config.settings import Settings | |
| from config.prompts import get_generic_prompt | |
| from openai import OpenAI | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | |
| logger = logging.getLogger(__name__) | |
| # ============================================ | |
| # App State (like Gradio's official example) | |
| # ============================================ | |
@dataclass
class AppState:
    """Mutable per-session state for the Gradio app (mirrors Gradio's
    official streaming-VAD example).

    Attributes:
        stream: Accumulated mono audio samples of the current utterance.
        sampling_rate: Sample rate (Hz) of ``stream``; 0 until audio arrives.
        pause_detected: True once energy-based VAD sees trailing silence.
        started_talking: True once a loud-enough chunk has been observed.
        stopped: True after the user presses Stop; prevents mic re-arming.
        conversation: Chat history as Gradio-style role/content dicts.
    """
    # BUG FIX: the original class used `field(default_factory=list)` and was
    # constructed with keyword args (AppState(stopped=True), AppState(
    # conversation=...)) but lacked the @dataclass decorator, so construction
    # raised TypeError and `conversation` was a shared dataclasses.Field
    # object. Adding @dataclass restores the intended behavior.
    stream: Optional[np.ndarray] = None
    sampling_rate: int = 0
    pause_detected: bool = False
    started_talking: bool = False
    stopped: bool = False
    conversation: List[Dict] = field(default_factory=list)
| # ============================================ | |
| # VAD Helper | |
| # ============================================ | |
def detect_pause(audio: np.ndarray, sr: int, state: AppState) -> bool:
    """Energy-based VAD: report a pause when earlier audio was loud but the
    trailing 0.5 s window has gone quiet.

    ``state`` is accepted for signature compatibility with the streaming
    caller and is not consulted.
    """
    SILENCE_THRESHOLD = 0.01

    def _rms(samples: np.ndarray) -> float:
        # Root-mean-square energy; int16 input is rescaled to [-1, 1].
        as_float = samples.astype(np.float32)
        if samples.dtype == np.int16:
            as_float = as_float / 32768.0
        return float(np.sqrt(np.mean(as_float ** 2)))

    # Too little audio (< 0.3 s) to make any judgement.
    if audio is None or len(audio) < sr * 0.3:
        return False

    window = int(sr * 0.5)
    tail = audio if len(audio) < window else audio[-window:]
    tail_energy = _rms(tail)

    # Need history beyond the tail window to judge "was loud earlier".
    if len(audio) <= window * 2:
        return False

    head_energy = _rms(audio[:-window])
    if head_energy > SILENCE_THRESHOLD * 2 and tail_energy < SILENCE_THRESHOLD:
        logger.info(f"Pause: earlier={head_energy:.4f}, now={tail_energy:.4f}")
        return True
    return False
def audio_to_wav_file(audio: np.ndarray, sr: int) -> str:
    """Peak-normalize ``audio``, convert it to 16-bit PCM, and write it to a
    temporary mono WAV file.

    Returns the path of the temp file; the caller is responsible for it
    (``delete=False``).
    """
    samples = audio.astype(np.float32)
    peak = np.max(np.abs(samples))
    if peak > 0:
        # Scale so the loudest sample hits full scale.
        samples = samples / peak
    pcm = (samples * 32767).astype(np.int16)

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    with wave.open(tmp.name, 'wb') as wav_out:
        wav_out.setnchannels(1)   # mono
        wav_out.setsampwidth(2)   # 16-bit samples
        wav_out.setframerate(sr)
        wav_out.writeframes(pcm.tobytes())
    return tmp.name
| # ============================================ | |
| # MCP | |
| # ============================================ | |
def start_mcp_server():
    """Spawn crm_mcp_server.py as a background child process.

    Honors ``Settings.mcp_auto_start`` (MCP_AUTO_START env var). Returns the
    ``subprocess.Popen`` handle, or None when auto-start is disabled or the
    spawn fails. An atexit hook terminates the child on interpreter exit.
    """
    if not getattr(Settings(), "mcp_auto_start", True):
        logger.info("MCP auto-start disabled via settings.")
        return None

    server_script = os.path.join(os.path.dirname(__file__), "crm_mcp_server.py")
    launch_cmd = [sys.executable, server_script]
    try:
        child = subprocess.Popen(
            launch_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        logger.info(f"Started CRM MCP server (PID={child.pid}) using: {launch_cmd}")
    except Exception as e:
        logger.error(f"Failed to start CRM MCP server: {e}")
        return None

    def _terminate_child():
        # Best-effort shutdown of the child process when this app exits.
        if child.poll() is None:
            logger.info("Stopping CRM MCP server...")
            try:
                child.terminate()
            except Exception:
                pass

    atexit.register(_terminate_child)
    return child
| # ============================================ | |
| # Chatbot | |
| # ============================================ | |
# Matches a bare "name(args)" call spanning the whole string; DOTALL lets the
# argument list contain newlines.
# NOTE(review): this module-level pattern appears unused — parse_tool_call
# compiles its own per-line pattern — kept as-is since it is part of the
# module's public surface.
TOOL_CALL_RE = re.compile(
    r'^\s*([a-zA-Z_][\w]*)\s*\((.*)\)\s*$', re.DOTALL
)
def parse_tool_call(text: str):
    """
    Extract ``(tool_name, kwargs)`` from model output such as::

        tool_name(a=1, b="x")

    The call may be surrounded by chatter or wrapped in ``` code fences.
    Only keyword arguments with literal values are supported; returns None
    when no parseable call is found.

    Fixes over the previous version:
      * debug output goes through logging instead of print();
      * a line that looks like a call but fails to parse no longer aborts
        the scan — earlier candidate lines are still tried;
      * defaults are paired with the *last* ``len(defaults)`` parameters
        (the old front-aligned zip mislabeled arguments when a bare
        positional name preceded keyword ones).
    """
    log = logging.getLogger(__name__)

    # Prefer the contents of the first code fence, if any.
    cleaned = text.strip()
    if "```" in cleaned:
        parts = cleaned.split("```")
        if len(parts) >= 2:
            cleaned = parts[1]

    pattern = re.compile(r'^([a-zA-Z_]\w*)\s*\((.*)\)\s*$')
    # Scan bottom-up: the model usually ends its answer with the call.
    for line in reversed(cleaned.splitlines()):
        line = line.strip()
        m = pattern.match(line)
        if not m:
            continue
        log.info("Tool call: %s", line)
        name, args_src = m.groups()
        args_src = args_src.strip()
        # No arguments at all.
        if not args_src:
            return name, {}
        try:
            # Wrap the args as a function signature so ast can parse them;
            # literal_eval rejects anything but plain literals.
            func_def = ast.parse(f"def _f({args_src}): pass").body[0]
            spec = func_def.args
            # Defaults align with the tail of the parameter list.
            named = spec.args[len(spec.args) - len(spec.defaults):]
            kwargs = {
                param.arg: ast.literal_eval(default)
                for param, default in zip(named, spec.defaults)
            }
            return name, kwargs
        except Exception as e:
            log.warning("Argument parse error on %r: %s", line, e)
            continue
    return None
class Chatbot:
    """Core assistant: wires STT/TTS (AudioService), a chat LLM (LLMService),
    a vision model (Nebius, via an OpenAI-compatible client), screen capture,
    and an MCP tool server into one conversational agent."""
    def __init__(self):
        # All credentials and model names are resolved by Settings
        # (env-var backed — see config.settings).
        self.settings = Settings()
        self.audio_service = AudioService(
            api_key=self.settings.hf_token,
            stt_provider="fal-ai",
            stt_model=self.settings.stt_model,
            tts_model=self.settings.tts_model,
        )
        self.llm_service = LLMService(
            api_key=self.settings.llm_api_key,
            model_name=self.settings.effective_model_name,
        )
        # Vision requests go through an OpenAI-compatible endpoint on Nebius.
        self.vision_client = OpenAI(
            base_url=self.settings.NEBIUS_BASE_URL,
            api_key=self.settings.NEBIUS_API_KEY
        )
        self.vision_model = self.settings.NEBIUS_MODEL
        self.screen_service = get_screen_service()
        # Rolling chat history (role/content dicts) shared across turns.
        self.history: list[dict] = []
        self.mcp = MCPClient()
        try:
            self.tools = self.mcp.list_tools()
        except Exception as e:
            # Fail gracefully: without a tool list the assistant still
            # chats, tools just won't be used.
            logging.exception("Failed to load tools from MCP server: %s", e)
            self.tools = []
        self.tools_description = self._build_tools_description()
    def _build_tools_description(self) -> str:
        """Build a human-readable list of tools for the system prompt."""
        if not getattr(self, "tools", None):
            return "No tools are currently available."
        lines = []
        for t in self.tools:
            # Each MCP tool advertises a name, a description, and a JSON
            # schema whose `properties` describe the accepted arguments.
            name = t.get("name", "unknown_tool")
            desc = t.get("description", "")
            props = t.get("inputSchema", {}).get("properties", {})
            args = ", ".join(
                f'{k}: {v.get("type", "string")}'
                for k, v in props.items()
            )
            lines.append(f"- {name}({args}) β {desc}")
        return "\n".join(lines)
    async def process(self, text: str, tts_enabled: bool = True) -> Tuple[str, Optional[str]]:
        """Run one user turn.

        Phase 1 asks the LLM for a reply; if that reply parses as a tool
        call, the tool is invoked via MCP and phase 2 feeds the result back
        to the LLM for a final answer. Optionally synthesizes TTS audio.

        Returns:
            (reply_text, path_to_temp_wav_or_None)
        """
        if not text.strip():
            return "", None
        # ---------- Phase 1: ask model what to do ----------
        messages = self.llm_service.build_messages_with_tools(
            system_prompt=get_generic_prompt(),
            user_input=text,
            tools_description=self.tools_description,
            conversation_history=self.history,
        )
        first_reply = await self.llm_service.get_chat_completion(messages)
        # Try to parse a tool call from the reply
        tool_call = parse_tool_call(first_reply)
        tool_result_str = None
        if tool_call:
            tool_name, tool_args = tool_call
            try:
                result = self.mcp.call_tool(tool_name, tool_args)
                tool_result_str = (
                    f"Tool {tool_name} succeeded with arguments {tool_args}.\n"
                    f"Result (JSON):\n{json.dumps(result, indent=2)}"
                )
            except Exception as e:
                # Feed the failure back so the model can explain it.
                tool_result_str = f"Tool {tool_name} failed: {e}"
            # ---------- Phase 2: give tool result back to model ----------
            messages = self.llm_service.build_messages_with_tools(
                system_prompt=get_generic_prompt(),
                user_input=text,
                tools_description=self.tools_description,
                conversation_history=self.history,
                tool_results=tool_result_str,
            )
            reply = await self.llm_service.get_chat_completion(messages)
        else:
            # No tool call: treat the initial text as the final answer.
            reply = first_reply
        # Save final user + assistant messages in conversation history
        self.history.append({"role": "user", "content": text})
        self.history.append({"role": "assistant", "content": reply})
        # ---------- Optional: TTS ----------
        audio_path = None
        if tts_enabled:
            audio_bytes = await self.audio_service.text_to_speech(reply)
            if audio_bytes:
                # Persist synthesized speech to a temp WAV for gr.Audio.
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
                tmp.write(audio_bytes)
                tmp.close()
                audio_path = tmp.name
        return reply, audio_path
    async def transcribe(self, audio_path: str) -> str:
        """Speech-to-text on the WAV at ``audio_path`` via AudioService."""
        return await self.audio_service.speech_to_text(audio_path)
    async def capture_screen(self, state: AppState, tts_enabled: bool) -> Tuple[List[Dict], Optional[str], AppState, str]:
        """Capture screen and send to vision model.

        Returns (conversation, audio_path_or_None, state, status_message),
        matching the Gradio handler's output components.
        """
        # Capture screenshot
        capture = self.screen_service.capture()
        if not capture:
            return state.conversation, None, state, "β Capture failed"
        logger.info(f"Screenshot captured: {capture.width}x{capture.height}")
        # Send to vision model; the client is synchronous, so run it in a
        # thread pool to avoid blocking the event loop.
        try:
            response = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda: self.vision_client.chat.completions.create(
                    model=self.vision_model,
                    messages=[{
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "Ignore the ATLAS interface on the left-most side of the screen. Describe the content of the main window."},
                            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{capture.image_b64}"}}
                        ]
                    }]
                )
            )
            analysis = response.choices[0].message.content
        except Exception as e:
            logger.error(f"Vision error: {e}")
            return state.conversation, None, state, f"β {str(e)}"
        # Process through main LLM
        reply, audio_path = await self.process(f"[Screen analysis]: {analysis}", tts_enabled)
        state.conversation.append({"role": "user", "content": "[Screen captured]"})
        state.conversation.append({"role": "assistant", "content": f"Vision output: {reply}. Incorporate the data in your context. Do not call any tool yet, await for further instructions."})
        return state.conversation, audio_path, state, "πΈ Done"
# Start MCP server as soon as this module is imported (deliberate side
# effect: the handlers below expect the tool server to be reachable).
mcp_proc = start_mcp_server()
# Give it a brief moment to bind to the port
time.sleep(1.0)
# Global Chatbot instance; created lazily once API tokens are available
# (see ensure_bot_initialized).
bot = None  # type: ignore
def tokens_present() -> bool:
    """Return True when both HF_TOKEN and NEBIUS_API_KEY are set and non-empty."""
    required = ("HF_TOKEN", "NEBIUS_API_KEY")
    return all(os.getenv(var) for var in required)
| ENV_PATH = Path(__file__).parent / ".env" | |
def upsert_env_var(key: str, value: str):
    """
    Persist ``key=value`` into the module's .env file.

    Rewrites the first matching line in place when the key already exists,
    otherwise appends one. Empty values are ignored. Parsing is deliberately
    naive: one key=value per line, no quoting or comments.
    """
    if not value:
        return

    existing = (
        ENV_PATH.read_text(encoding="utf-8").splitlines()
        if ENV_PATH.exists()
        else []
    )

    replaced = False
    updated = []
    for line in existing:
        if not replaced and line.startswith(f"{key}="):
            updated.append(f"{key}={value}")
            replaced = True
        else:
            updated.append(line)
    if not replaced:
        updated.append(f"{key}={value}")

    ENV_PATH.write_text("\n".join(updated) + "\n", encoding="utf-8")
def ensure_bot_initialized() -> Optional[str]:
    """
    Build the global Chatbot on first use, if tokens are available.

    Returns None on success (or when the bot already exists); otherwise an
    error message asking the user to supply tokens.
    """
    global bot
    if bot is not None:
        return None  # already constructed

    token = os.getenv("HF_TOKEN", "")
    if not token or len(token) <= 10:
        return "β οΈ HF_TOKEN missing or invalid. Please fill it in the Setup section."

    # Debug aid: surface which token Settings actually resolved before
    # constructing the services.
    resolved = Settings()
    logger.info(
        f"Initializing Chatbot with HF token prefix={resolved.hf_token[:4]}..., len={len(resolved.hf_token)}"
    )
    bot = Chatbot()
    return None
def save_tokens(hf_token: str, nebius_api_key: str) -> str:
    """Persist tokens from the UI into the process env and .env file, then
    attempt to initialize the assistant. Returns a status message."""
    cleaned_hf = hf_token.strip() if hf_token else ""

    # Basic sanity check on the HF token shape.
    if hf_token and not cleaned_hf.startswith("hf_"):
        return "β HF_TOKEN does not look like a Hugging Face token (should start with 'hf_')."

    if hf_token:
        os.environ["HF_TOKEN"] = cleaned_hf
        upsert_env_var("HF_TOKEN", cleaned_hf)
    if nebius_api_key:
        cleaned_nebius = nebius_api_key.strip()
        os.environ["NEBIUS_API_KEY"] = cleaned_nebius
        upsert_env_var("NEBIUS_API_KEY", cleaned_nebius)

    # NOW build Chatbot + LLMService with the *current* env.
    err = ensure_bot_initialized()
    return err if err else "β Tokens saved and assistant initialized. You can now use Atlas."
def check_tokens_on_load():
    """App-load handler: hide the token inputs when env already has tokens
    (initializing the bot immediately); otherwise show them with a prompt.

    Returns updates for (hf_token_box, nebius_key_box, setup_status).
    """
    if not tokens_present():
        return (
            gr.update(visible=True),
            gr.update(visible=True),
            "β οΈ Please paste your HF_TOKEN and NEBIUS_API_KEY to start.",
        )

    # Env already has HF_TOKEN/NEBIUS_API_KEY: build Chatbot immediately.
    err = ensure_bot_initialized()
    status = err if err else "β Tokens loaded from .env. Atlas is ready."
    return (
        gr.update(visible=False),  # hf_token_box
        gr.update(visible=False),  # nebius_key_box
        status,
    )
| # ============================================ | |
| # Gradio Handlers | |
| # ============================================ | |
def process_audio(audio: tuple, state: AppState):
    """Streaming mic callback: accumulate chunks, mark speech start, and stop
    the recording once a pause follows speech.

    Returns (mic_update_or_None, state); returning gr.Audio(recording=False)
    tells Gradio to stop capturing.
    """
    if audio is None:
        return None, state

    sr, chunk = audio
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)  # downmix to mono

    # Append this chunk to the running utterance buffer.
    if state.stream is None:
        state.stream = chunk
        state.sampling_rate = sr
    else:
        state.stream = np.concatenate((state.stream, chunk))

    # RMS energy of the current chunk decides whether speech has started.
    chunk_f = chunk.astype(np.float32)
    if chunk.dtype == np.int16:
        chunk_f = chunk_f / 32768.0
    energy = float(np.sqrt(np.mean(chunk_f ** 2)))
    if energy > 0.015:
        state.started_talking = True
        logger.debug(f"Talking: energy={energy:.4f}")

    # Stop the mic only when silence follows detected speech.
    state.pause_detected = detect_pause(state.stream, state.sampling_rate, state)
    if state.pause_detected and state.started_talking:
        logger.info("Pause detected - stopping recording")
        return gr.Audio(recording=False), state
    return None, state
async def respond(state: AppState, tts_enabled: bool):
    """After recording stops: transcribe the buffered audio, run the bot, and
    return (audio_path_or_None, fresh_AppState, conversation). The fresh
    state keeps the conversation but clears the audio buffer and VAD flags.
    """
    convo = state.conversation

    def _fresh_state() -> AppState:
        # New per-turn state carrying only the chat history forward.
        return AppState(conversation=convo)

    if bot is None:
        msg = "β οΈ Configure HF_TOKEN and NEBIUS_API_KEY in the Setup section before using voice."
        convo.append({"role": "assistant", "content": msg})
        return None, _fresh_state(), convo

    # Too few samples means nothing was really recorded.
    if state.stream is None or len(state.stream) < 1000:
        logger.info("No audio")
        return None, _fresh_state(), convo

    logger.info(f"Processing {len(state.stream)} samples...")
    wav_path = audio_to_wav_file(state.stream, state.sampling_rate)
    transcript = await bot.transcribe(wav_path)
    logger.info(f"Transcript: {transcript}")
    if not transcript.strip():
        return None, _fresh_state(), convo

    reply, audio_path = await bot.process(transcript, tts_enabled)
    convo.append({"role": "user", "content": transcript})
    convo.append({"role": "assistant", "content": reply})
    return audio_path, _fresh_state(), convo
def start_recording(state: AppState):
    """Re-arm the microphone unless the user explicitly stopped the session."""
    return gr.Audio(recording=not state.stopped)
async def send_text(text: str, state: AppState, tts_enabled: bool):
    """Text-box handler: run the bot on `text` and append the exchange.

    Returns (conversation, audio_path_or_None, state, "") — the trailing
    empty string clears the textbox.
    """
    if not text.strip():
        return state.conversation, None, state, ""

    if bot is None:
        warning = "β οΈ Configure HF_TOKEN and NEBIUS_API_KEY in the Setup section before chatting."
        state.conversation.append({"role": "assistant", "content": warning})
        return state.conversation, None, state, ""

    reply, audio_path = await bot.process(text, tts_enabled)
    state.conversation.extend([
        {"role": "user", "content": text},
        {"role": "assistant", "content": reply},
    ])
    return state.conversation, audio_path, state, ""
async def capture_screen_handler(state: AppState, tts_enabled: bool):
    """Delegate screen capture to the bot; short-circuit with a warning when
    the assistant has not been initialized yet."""
    if bot is not None:
        return await bot.capture_screen(state, tts_enabled)
    msg = "β οΈ Configure HF_TOKEN and NEBIUS_API_KEY in the Setup section before using screen capture."
    return state.conversation, None, state, msg
| # ============================================ | |
| # UI | |
| # ============================================ | |
# Top-level Gradio UI definition; `demo` is the app object served by launch().
with gr.Blocks(title="ATLAS") as demo:
    gr.Markdown("### Atlas - CRM Voice Assistant")
    # Per-session state: audio buffer, VAD flags, conversation history.
    state = gr.State(value=AppState())
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Conversation", height=400)
            with gr.Row():
                txt = gr.Textbox(placeholder="Type here your message...", label="Input", scale=4)
                send_btn = gr.Button("Send", scale=1)
        with gr.Column(scale=1):
            # Setup section: API-key entry (hidden on load when tokens exist)
            gr.Markdown("### Setup (API keys)")
            hf_token_box = gr.Textbox(
                placeholder="Paste your HuggingFace token (HF_TOKEN)",
                label="HF_TOKEN",
                type="password"
            )
            nebius_key_box = gr.Textbox(
                placeholder="Paste your Nebius API key (NEBIUS_API_KEY)",
                label="NEBIUS_API_KEY",
                type="password"
            )
            save_keys_btn = gr.Button("Save keys & initialize Atlas")
            setup_status = gr.Markdown("")
            gr.Markdown("---")
            gr.Markdown("### Speech module")
            # Streaming mic feeds process_audio every 0.5 s for VAD.
            mic = gr.Audio(
                sources=["microphone"],
                type="numpy",
                label="Microphone",
                streaming=True,
            )
            audio_out = gr.Audio(label="Response", autoplay=True, streaming=True)
            tts_toggle = gr.Checkbox(label="π TTS Enabled", value=True)
            stop_btn = gr.Button("π Stop", variant="stop")
            gr.Markdown("---")
            gr.Markdown("### π₯οΈ Screen")
            capture_btn = gr.Button("πΈ Capture Screen")
            screen_status = gr.Textbox(label="Status", value="Ready", interactive=False)
    # Stream -> detect pause -> stop
    mic.stream(
        process_audio,
        inputs=[mic, state],
        outputs=[mic, state],
        stream_every=0.5,
        time_limit=60,
    )
    # Stop -> transcribe -> respond -> restart
    mic.stop_recording(
        respond,
        inputs=[state, tts_toggle],
        outputs=[audio_out, state, chatbot],
    ).then(
        start_recording,
        inputs=[state],
        outputs=[mic],
    )
    # Hard stop: mark the session stopped so start_recording won't re-arm.
    stop_btn.click(
        lambda: (AppState(stopped=True), gr.Audio(recording=False)),
        outputs=[state, mic],
    )
    # Text entry: the Send button and the Enter key share one handler.
    send_btn.click(send_text, inputs=[txt, state, tts_toggle], outputs=[chatbot, audio_out, state, txt])
    txt.submit(send_text, inputs=[txt, state, tts_toggle], outputs=[chatbot, audio_out, state, txt])
    # Screen capture
    capture_btn.click(
        capture_screen_handler,
        inputs=[state, tts_toggle],
        outputs=[chatbot, audio_out, state, screen_status]
    )
    # When app loads, show/hide token inputs based on env
    demo.load(
        fn=check_tokens_on_load,
        inputs=None,
        outputs=[hf_token_box, nebius_key_box, setup_status],
    )
    # When user clicks "Save keys"
    save_keys_btn.click(
        fn=save_tokens,
        inputs=[hf_token_box, nebius_key_box],
        outputs=[setup_status],
    )
if __name__ == "__main__":
    # BUG FIX: `theme=` is an argument of the gr.Blocks(...) constructor, not
    # of launch(); passing it here raised TypeError at startup. The theme, if
    # desired, belongs on the `with gr.Blocks(...)` call above.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,
    )