# rayymaxx — Modified custom wrapper (commit 92d190f)
import requests
import json
from typing import Any, List, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage
class CustomChatModel(SimpleChatModel):
    """A custom chat model that calls a remote FastAPI endpoint.

    Sends the content of the last message as ``{"prompt": ...}`` to
    ``api_url`` and parses the JSON body of the reply into a plain string.
    """

    # Base URL of the FastAPI completion endpoint.
    api_url: str
    # Seconds to wait for the remote endpoint. Without a timeout,
    # requests.post can hang forever on a stalled server.
    request_timeout: float = 60.0

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for logging/serialization."""
        return "custom_chat_model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """POST the last message's content to the API and return the reply text.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors
                (after logging, re-raised so callers can fall back).
            ValueError: if the model returns an empty response, or the body
                is not valid JSON.
        """
        raw_prompt = messages[-1].content
        data = {"prompt": raw_prompt}
        try:
            # ``json=`` sets the Content-Type header and serializes in one
            # step; ``timeout`` bounds how long we wait for the endpoint.
            response = requests.post(
                self.api_url, json=data, timeout=self.request_timeout
            )
            response.raise_for_status()
            result = response.json()

            # Case 1: backend returns a JSON object with a "response" field.
            # The dict check must come FIRST: on a plain string,
            # ``"response" in result`` is a substring test and
            # ``result["response"]`` would raise TypeError.
            if isinstance(result, dict) and "response" in result:
                full_text = result["response"]
            # Case 2: backend just returns a raw JSON string.
            elif isinstance(result, str):
                full_text = result
            # Case 3: unexpected JSON -> stringify so we still return text.
            else:
                full_text = json.dumps(result)

            # Strip Llama/HuggingFace-style chat-template tags if present.
            if "<|start_header_id|>" in full_text:
                parts = full_text.split(
                    "<|start_header_id|>assistant<|end_header_id|>\n\n"
                )
                if len(parts) > 1:
                    assistant_response = parts[1].replace("<|eot_id|>", "").strip()
                    if assistant_response:
                        return assistant_response

            # Fallback: just clean & return.
            clean_response = full_text.strip()
            if not clean_response:
                raise ValueError("Model returned an empty response.")
            return clean_response
        except (requests.exceptions.RequestException, ValueError, json.JSONDecodeError) as e:
            print(f"❌ Custom model failed: {e}. Attempting fallback.")
            raise