|
|
| """Handler for chat endpoint operations."""
|
| import os
|
| import asyncio
|
| import logging
|
| from typing import Dict, List, Optional, Any
|
|
|
| from fastapi import HTTPException
|
|
|
| from src.constants import (
|
| MAX_CONTEXT_MESSAGES,
|
| DEFAULT_TEMPERATURE,
|
| DEFAULT_MAX_TOKENS,
|
| UPLOAD_DIR,
|
| )
|
| from core.models import ChatMessage
|
| from src.chat_helpers import extract_urls, model_supports_vision
|
| from src.document_processor import build_user_content, analyze_image_with_vl_result
|
| from src.youtube_handler import (
|
| is_youtube_url,
|
| extract_youtube_id,
|
| extract_transcript_async,
|
| format_transcript_for_context,
|
| fetch_youtube_comments,
|
| format_comments_for_context,
|
| YOUTUBE_INSTRUCTION_PROMPT,
|
| )
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
| class ChatHandler:
|
| """Handles chat operations for both streaming and non-streaming endpoints."""
|
|
|
| def __init__(
|
| self,
|
| session_manager,
|
| memory_manager,
|
| chat_processor,
|
| research_handler,
|
| preset_manager,
|
| upload_handler,
|
| ):
|
| self.session_manager = session_manager
|
| self.memory_manager = memory_manager
|
| self.chat_processor = chat_processor
|
| self.research_handler = research_handler
|
| self.preset_manager = preset_manager
|
| self.upload_handler = upload_handler
|
|
|
|
|
|
|
|
|
|
|
| def validate_and_extract_preset(self, preset_id: Optional[str]) -> tuple:
|
| """Returns (temperature, max_tokens, preset_system_prompt, character_name)."""
|
| if preset_id and preset_id not in self.preset_manager.presets:
|
| raise HTTPException(400, f"Invalid preset_id: {preset_id}")
|
|
|
| temperature = DEFAULT_TEMPERATURE
|
| max_tokens = DEFAULT_MAX_TOKENS
|
| preset_system_prompt = None
|
| character_name = ""
|
|
|
| if preset_id and preset_id in self.preset_manager.presets:
|
| preset = self.preset_manager.presets[preset_id]
|
| if preset.get("enabled") is False:
|
| logger.info(f"Preset {preset_id} is disabled, using defaults")
|
| return temperature, max_tokens, preset_system_prompt, character_name
|
| if preset.get("system_prompt"):
|
| preset_system_prompt = preset["system_prompt"]
|
| character_name = preset.get("character_name", "")
|
| if character_name:
|
| name_line = f"Your name is {character_name}."
|
| if preset_system_prompt:
|
| preset_system_prompt = f"{name_line} {preset_system_prompt}"
|
| else:
|
| preset_system_prompt = name_line
|
| if "temperature" in preset:
|
| temperature = preset["temperature"]
|
| if "max_tokens" in preset:
|
| max_tokens = preset["max_tokens"]
|
|
|
| logger.info(f"Preset {preset_id}: temp={temperature}, max_tokens={max_tokens}")
|
| return temperature, max_tokens, preset_system_prompt, character_name
|
|
|
| def enhance_message_if_needed(self, message: str) -> str:
|
| """CoT enhancement disabled — modern models reason natively."""
|
| return message
|
|
|
|
|
|
|
|
|
|
|
| async def preprocess_message(
|
| self,
|
| message: str,
|
| att_ids: List[str],
|
| sess,
|
| auto_opened_docs: Optional[List[Dict[str, Any]]] = None,
|
| allow_tool_preprocessing: bool = True,
|
| ) -> tuple:
|
| """
|
| Common preprocessing for both chat endpoints.
|
|
|
| Returns (enhanced_message, user_content, text_for_context, youtube_transcripts, attachment_meta)
|
|
|
| If `auto_opened_docs` is provided, server-side document auto-creation
|
| (e.g. from an attached fillable PDF) appends entries describing the
|
| new doc so the caller can announce it to the frontend before streaming.
|
| """
|
| enhanced_message = message
|
| attachment_meta: List[Dict[str, Any]] = []
|
|
|
|
|
| urls = extract_urls(enhanced_message) if allow_tool_preprocessing else []
|
| youtube_transcripts: List[str] = []
|
|
|
| has_youtube = False
|
| for url in urls:
|
| if is_youtube_url(url):
|
| video_id = extract_youtube_id(url)
|
| if not video_id:
|
| continue
|
| has_youtube = True
|
| logger.info(f"Processing YouTube URL: {url}")
|
|
|
| transcript_task = extract_transcript_async(url, video_id)
|
| comments_task = fetch_youtube_comments(video_id)
|
| transcript_data, comments_data = await asyncio.gather(
|
| transcript_task, comments_task
|
| )
|
|
|
| title = comments_data.get("title", "")
|
| channel = comments_data.get("channel", "")
|
| youtube_transcripts.append(
|
| format_transcript_for_context(transcript_data, url, title, channel)
|
| )
|
| comments_ctx = format_comments_for_context(comments_data, url)
|
| if comments_ctx:
|
| youtube_transcripts.append(comments_ctx)
|
|
|
|
|
| if has_youtube:
|
| youtube_transcripts.insert(0, YOUTUBE_INSTRUCTION_PROMPT)
|
|
|
|
|
|
|
| files_by_id: Dict[str, Dict] = {}
|
| owner = getattr(sess, "owner", None)
|
| effective_att_ids = att_ids if allow_tool_preprocessing else []
|
| if effective_att_ids:
|
| for att_id in effective_att_ids:
|
| fi = self.upload_handler.resolve_upload(att_id, owner=owner)
|
| if fi:
|
| files_by_id[att_id] = fi
|
|
|
| for att_id in effective_att_ids:
|
| fi = files_by_id.get(att_id)
|
| if fi:
|
| attachment_meta.append({
|
| "id": fi["id"],
|
| "name": fi.get("name") or fi.get("original_name") or fi["id"],
|
| "mime": fi.get("mime", ""),
|
| "size": fi.get("size", 0),
|
| "width": fi.get("width"),
|
| "height": fi.get("height"),
|
| })
|
|
|
|
|
|
|
|
|
| vision_enabled = False
|
| main_is_vision = False
|
| if effective_att_ids:
|
| from src.settings import get_setting
|
| vision_enabled = get_setting("vision_enabled", True)
|
| if vision_enabled:
|
| main_is_vision = await asyncio.to_thread(
|
| model_supports_vision,
|
| sess.model or "",
|
| getattr(sess, "endpoint_url", "") or "",
|
| )
|
|
|
| if effective_att_ids and vision_enabled:
|
| meta_by_id = {m["id"]: m for m in attachment_meta}
|
| for att_id in effective_att_ids:
|
| file_info = files_by_id.get(att_id)
|
| if file_info and self.upload_handler.is_image_file(
|
| file_info["name"], file_info.get("mime", "")
|
| ):
|
| if main_is_vision:
|
|
|
| enhanced_message = f"{enhanced_message}\n\n[Image attached: {file_info['name']}]"
|
| _m = meta_by_id.get(att_id)
|
| if _m is not None:
|
| _m["vision_model"] = sess.model or ""
|
|
|
|
|
|
|
|
|
|
|
| _vcache = os.path.join(UPLOAD_DIR, ".vision", att_id + ".txt")
|
| if os.path.exists(_vcache):
|
| try:
|
| with open(_vcache, encoding="utf-8") as _vf:
|
| _vtext = _vf.read().strip()
|
| if _vtext:
|
| enhanced_message += f"\n[User-corrected caption / OCR for this image — treat as authoritative]:\n{_vtext}"
|
| _m = meta_by_id.get(att_id)
|
| if _m is not None:
|
| _m["vision"] = _vtext
|
| except Exception:
|
| pass
|
| else:
|
|
|
|
|
|
|
|
|
| _vcache = os.path.join(UPLOAD_DIR, ".vision", att_id + ".txt")
|
| vl_desc = None
|
| vl_model = get_setting("vision_model", "") or ""
|
| if os.path.exists(_vcache):
|
| try:
|
| with open(_vcache, encoding="utf-8") as _vf:
|
| cached_desc = _vf.read().strip()
|
| if cached_desc and not cached_desc.startswith("["):
|
| vl_desc = cached_desc
|
| except Exception:
|
| vl_desc = None
|
| if not vl_desc:
|
| vl_result = analyze_image_with_vl_result(file_info["path"], owner=owner)
|
| vl_desc = vl_result.get("text", "")
|
| vl_model = vl_result.get("model", "")
|
| if vl_desc and not vl_desc.startswith("["):
|
| try:
|
| os.makedirs(os.path.join(UPLOAD_DIR, ".vision"), exist_ok=True)
|
| with open(_vcache, "w", encoding="utf-8") as _vf:
|
| _vf.write(vl_desc)
|
| except Exception:
|
| pass
|
| enhanced_message = f"{enhanced_message}\n\n[Image: {file_info['name']}]\n{vl_desc}"
|
|
|
|
|
|
|
| _m = meta_by_id.get(att_id)
|
| if _m is not None:
|
| _m["vision"] = vl_desc
|
| _m["vision_model"] = vl_model
|
|
|
| user_content = build_user_content(
|
| enhanced_message, effective_att_ids, UPLOAD_DIR, self.upload_handler,
|
| session_id=getattr(sess, "id", None),
|
| auto_opened_docs=auto_opened_docs,
|
| owner=owner,
|
| resolved_uploads=files_by_id,
|
| )
|
|
|
|
|
| if not vision_enabled and isinstance(user_content, list):
|
| text_parts = [
|
| item.get("text", "") for item in user_content
|
| if isinstance(item, dict) and item.get("type") == "text"
|
| ]
|
| user_content = "\n".join(text_parts).strip() if text_parts else enhanced_message
|
| elif not main_is_vision and isinstance(user_content, list):
|
| text_parts = [
|
| item.get("text", "") for item in user_content
|
| if isinstance(item, dict) and item.get("type") == "text"
|
| ]
|
| user_content = "\n".join(text_parts).strip() if text_parts else enhanced_message
|
|
|
|
|
| if isinstance(user_content, list):
|
| text_for_context = next(
|
| (item["text"] for item in user_content if item.get("type") == "text"),
|
| enhanced_message,
|
| )
|
| else:
|
| text_for_context = user_content
|
|
|
| return enhanced_message, user_content, text_for_context, youtube_transcripts, attachment_meta
|
|
|
|
|
|
|
|
|
|
|
| def update_session_name_if_needed(self, session, message: str):
|
| if not session.name:
|
| derived = " ".join(message.split()[:5])
|
| session.name = "Chat: " + derived if derived else "Chat"
|
|
|
| def trim_history_if_needed(self, session):
|
| if len(session.history) > MAX_CONTEXT_MESSAGES:
|
| session.history = session.history[-MAX_CONTEXT_MESSAGES:]
|
|
|
| async def handle_memory_command(self, session, message: str) -> Optional[str]:
|
| """Process inline memory commands. Returns response string or None."""
|
| is_memory_cmd, memory_text = self.memory_manager.process_inline_memory_command(
|
| message
|
| )
|
| if is_memory_cmd and memory_text:
|
| mem = self.memory_manager.load()
|
| if not self.memory_manager.find_duplicates(memory_text, mem):
|
| new_entry = self.memory_manager.add_entry(memory_text)
|
| mem.append(new_entry)
|
| self.memory_manager.save(mem)
|
|
|
| session.add_message(ChatMessage("user", message))
|
| session.add_message(
|
| ChatMessage("assistant", f"Saved to memory: {memory_text}")
|
| )
|
|
|
| from src.database import update_session_last_accessed
|
|
|
| update_session_last_accessed(session.id)
|
| self.session_manager.save_sessions()
|
| return f"Saved to memory: {memory_text}"
|
| return None
|
|
|