import json
import logging

import requests

from typing import Any, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage

logger = logging.getLogger(__name__)


class CustomChatModel(SimpleChatModel):
    """A custom chat model that calls a remote FastAPI endpoint.

    Sends the content of the most recent message as ``{"prompt": ...}``
    to ``api_url`` and returns the cleaned text of the response.
    """

    # Base URL of the FastAPI endpoint that accepts {"prompt": str}.
    api_url: str
    # Seconds to wait for the remote call; prevents an unbounded hang.
    request_timeout: float = 60.0

    @property
    def _llm_type(self) -> str:
        return "custom_chat_model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send the last message's content to the API and return its reply.

        Args:
            messages: Conversation so far; only the last message's content
                is forwarded to the backend.
            stop: Accepted for interface compatibility; not forwarded to
                the backend (it has no stop-sequence support here).
            run_manager: Optional LangChain callback manager (unused).

        Returns:
            The model's reply with HuggingFace/Llama chat tags stripped.

        Raises:
            ValueError: if no messages were supplied or the model returned
                an empty response.
            requests.exceptions.RequestException: on network/HTTP failure.
        """
        if not messages:
            raise ValueError("At least one message is required.")
        raw_prompt = messages[-1].content

        try:
            # `json=` sets the Content-Type header and serializes for us;
            # `timeout` guards against the request hanging forever.
            response = requests.post(
                self.api_url,
                json={"prompt": raw_prompt},
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            result = response.json()

            # Case 1: backend returns an object with a "response" field.
            # The isinstance guard matters: a raw-string result would pass
            # a bare `"response" in result` substring check and then crash
            # on `result["response"]`.
            if isinstance(result, dict) and "response" in result:
                full_text = result["response"]
            # Case 2: backend returns a raw JSON string.
            elif isinstance(result, str):
                full_text = result
            # Case 3: unexpected JSON shape -> stringify so we never crash.
            else:
                full_text = json.dumps(result)

            # Strip HuggingFace/Llama-style chat template tags if present.
            if "<|start_header_id|>" in full_text:
                parts = full_text.split(
                    "<|start_header_id|>assistant<|end_header_id|>\n\n"
                )
                if len(parts) > 1:
                    assistant_response = parts[1].replace("<|eot_id|>", "").strip()
                    if assistant_response:
                        return assistant_response

            # Fallback: just clean & return.
            clean_response = full_text.strip()
            if not clean_response:
                raise ValueError("Model returned an empty response.")
            return clean_response

        except (
            requests.exceptions.RequestException,
            ValueError,
            json.JSONDecodeError,
        ) as e:
            # Log and re-raise so the caller can run its own fallback.
            logger.error("Custom model failed: %s. Attempting fallback.", e)
            raise