Spaces:
Sleeping
Sleeping
import json
from typing import Any, List, Optional

import requests
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import BaseMessage
class CustomChatModel(SimpleChatModel):
    """A custom chat model that calls a remote FastAPI endpoint.

    POSTs the latest message's content as ``{"prompt": ...}`` to ``api_url``
    and returns the endpoint's text reply, stripping HuggingFace-style
    chat-template tags (``<|start_header_id|>`` etc.) when present.
    """

    # Base URL of the FastAPI inference endpoint (accepts a JSON POST).
    api_url: str
    # Seconds before the HTTP request is aborted. The original request had
    # no timeout at all, so a stuck server could hang the caller forever.
    request_timeout: float = 60.0

    @property
    def _llm_type(self) -> str:
        # LangChain's base class declares `_llm_type` as a *property*; the
        # original defined it as a plain method, so framework code reading
        # `self._llm_type` received a bound method instead of a string.
        return "custom_chat_model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send the last message's content to the endpoint and return its reply.

        Raises:
            requests.exceptions.RequestException: on network or HTTP errors.
            ValueError: if the model returns an empty response.
            json.JSONDecodeError: if the endpoint returns invalid JSON.
        """
        raw_prompt = messages[-1].content
        payload = {"prompt": raw_prompt}
        try:
            # `json=` serializes the body and sets the Content-Type header
            # for us; the timeout bounds how long we wait on the server.
            response = requests.post(
                self.api_url, json=payload, timeout=self.request_timeout
            )
            response.raise_for_status()
            result = response.json()

            # Check for a raw string FIRST. The original tested
            # `"response" in result` before the isinstance check, so a plain
            # string reply containing the substring "response" fell into the
            # dict branch and crashed on `result["response"]`.
            if isinstance(result, str):
                full_text = result
            # Case: backend returns a JSON object with a "response" field.
            elif isinstance(result, dict) and "response" in result:
                full_text = result["response"]
                if not isinstance(full_text, str):
                    # Defensive: stringify a non-string payload value.
                    full_text = json.dumps(full_text)
            # Case: unexpected JSON shape -> stringify so callers get text.
            else:
                full_text = json.dumps(result)

            # Try to strip HuggingFace-style tags if present.
            if "<|start_header_id|>" in full_text:
                parts = full_text.split(
                    "<|start_header_id|>assistant<|end_header_id|>\n\n"
                )
                if len(parts) > 1:
                    assistant_response = parts[1].replace("<|eot_id|>", "").strip()
                    if assistant_response:
                        return assistant_response

            # Fallback: just clean & return.
            clean_response = full_text.strip()
            if not clean_response:
                raise ValueError("Model returned an empty response.")
            return clean_response
        except (requests.exceptions.RequestException, ValueError, json.JSONDecodeError) as e:
            print(f"❌ Custom model failed: {e}. Attempting fallback.")
            raise