import requests

from typing import Any, AsyncGenerator, List, Optional, Sequence

from llama_index.core.llms import LLM
from llama_index.core.llms.callbacks import llm_chat_callback, llm_completion_callback
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    ChatResponseAsyncGen,
    CompletionResponse,
    CompletionResponseAsyncGen,
    CompletionResponseGen,
    LLMMetadata,
    MessageRole,
)
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex

class Kognie(LLM):
    """
    A custom LLM that calls a FastAPI server (the /text and /image endpoints).
    """

    base_url: str = 'http://api2.kognie.com'
    api_key: str
    model: str
    response_format: str = 'url'

    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(model_name=self.model)

    def _generate_text(
        self,
        prompt: str,
        model: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        The single-turn text generation method.

        LlamaIndex calls `_generate_text` internally whenever it needs a completion.
        """
        selected_model = model if model else self.model
        endpoint = f"{self.base_url}/text"

        params = {
            "question": prompt,
            "model": selected_model,
        }
        headers = {
            "X-KEY": self.api_key,
        }

        try:
            response = requests.get(endpoint, params=params, headers=headers)
            response.raise_for_status()
        except requests.HTTPError as exc:
            raise ValueError(f"FastAPI /text endpoint error: {exc}") from exc

        data = response.json()
        text_output = data.get("response", "")
        return text_output

    def _generate_image(
        self,
        prompt: str,
        model: str,
        response_format: str,
        **kwargs
    ) -> str:
        """
        The single-turn image generation method.

        Calls the /image endpoint and returns the generated image in the
        requested `response_format`.
        """
        selected_model = model if model else self.model
        endpoint = f"{self.base_url}/image"

        params = {
            "question": prompt,
            "model": selected_model,
            "response_format": response_format,
        }
        headers = {
            "X-KEY": self.api_key,
        }

        try:
            response = requests.get(endpoint, params=params, headers=headers)
            response.raise_for_status()
        except requests.HTTPError as exc:
            raise ValueError(f"FastAPI /image endpoint error: {exc}") from exc

        data = response.json()
        image_output = data.get("response", "")
        return image_output

    def generate_img(
        self,
        prompt: str,
        model: str,
        response_format: str,
    ) -> ChatMessage:
        img_output = self._generate_image(
            prompt=prompt,
            model=model,
            response_format=response_format,
        )
        return ChatMessage(role="assistant", content=img_output)

    def chat(
        self,
        messages: List[ChatMessage],
        model: Optional[str] = None,
        **kwargs
    ) -> ChatResponse:
        """
        Handles multi-turn, chat-style conversation. In LlamaIndex, chat
        engines and agents call `chat(messages=...)`, so the message list is
        flattened into a single prompt and sent to the /text endpoint.
        """
        conversation_log = ""
        for m in messages:
            role = m.role
            content = m.content
            if role == "user":
                conversation_log += f"User: {content}\n"
            else:
                conversation_log += f"{role.capitalize()}: {content}\n"

        text_output = self._generate_text(
            prompt=conversation_log,
            model=model,
            **kwargs
        )
        return ChatResponse(
            message=ChatMessage(role="assistant", content=text_output)
        )

    def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
        prompt = ""
        for message in messages:
            if message.role == MessageRole.SYSTEM:
                prompt += f"<|system|>\n{message.content}</s>\n"
            elif message.role == MessageRole.USER:
                prompt += f"<|user|>\n{message.content}</s>\n"
            elif message.role == MessageRole.ASSISTANT:
                prompt += f"<|assistant|>\n{message.content}</s>\n"

        # Ensure the prompt always starts with a (possibly empty) system block.
        if not prompt.startswith("<|system|>\n"):
            prompt = "<|system|>\n</s>\n" + prompt

        prompt += "<|assistant|>\n"
        return prompt

    async def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> AsyncGenerator[ChatResponse, None]:
        # `astream_complete` is a coroutine that returns an async generator,
        # so it has to be awaited before iterating.
        async for completion_response in await self.astream_complete(
            self.messages_to_prompt(messages), **kwargs
        ):
            chat_response = self.convert_completion_to_chat(completion_response)
            yield chat_response

    def convert_completion_to_chat(
        self, completion_response: CompletionResponse
    ) -> ChatResponse:
        return ChatResponse(
            message=ChatMessage(role="assistant", content=completion_response.text)
        )

    @llm_chat_callback()
    async def achat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponse:
        return self.chat(messages, **kwargs)

    @llm_chat_callback()
    async def astream_chat(
        self,
        messages: Sequence[ChatMessage],
        **kwargs: Any,
    ) -> ChatResponseAsyncGen:
        async def gen() -> ChatResponseAsyncGen:
            # `stream_chat` is an async generator, so consume it with `async for`.
            async for message in self.stream_chat(messages, **kwargs):
                yield message

        return gen()

    @llm_completion_callback()
    async def acomplete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        return self.complete(prompt, formatted=formatted, **kwargs)

    @llm_completion_callback()
    def complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponse:
        # Call the backend directly rather than recursing into `complete` itself.
        text_output = self._generate_text(prompt=prompt, **kwargs)
        return CompletionResponse(text=text_output)

    @llm_completion_callback()
    async def astream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
        async def gen() -> CompletionResponseAsyncGen:
            for message in self.stream_complete(prompt, formatted=formatted, **kwargs):
                yield message

        return gen()

    @llm_completion_callback()
    def stream_complete(
        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseGen:
        def gen() -> CompletionResponseGen:
            # The /text endpoint returns the whole answer at once, so yield the
            # full completion as a single chunk instead of recursing.
            yield self.complete(prompt, formatted=formatted, **kwargs)

        return gen()

    @classmethod
    def class_name(cls) -> str:
        return "custom_llm"
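

# Illustrative sketch of wiring Kognie into a LlamaIndex query pipeline, which
# is presumably why SimpleDirectoryReader and VectorStoreIndex are imported at
# the top. The "./data" path is a placeholder, and VectorStoreIndex also needs
# an embedding model (Settings.embed_model), which this sketch does not set up.
def demo_rag_pipeline() -> None:
    from llama_index.core import Settings

    Settings.llm = Kognie(api_key="YOUR_API_KEY", model="example-model")
    documents = SimpleDirectoryReader("./data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    query_engine = index.as_query_engine()
    print(query_engine.query("Summarize the documents."))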
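

# Minimal usage example. "YOUR_API_KEY" and "example-model" are placeholders,
# not real credentials or documented model names.
if __name__ == "__main__":
    llm = Kognie(api_key="YOUR_API_KEY", model="example-model")

    # Single-turn completion: `complete` sends the prompt to the /text endpoint.
    completion = llm.complete("Explain retrieval-augmented generation in one sentence.")
    print(completion.text)

    # Multi-turn chat: the message list is flattened into one prompt by `chat`.
    chat_response = llm.chat(
        messages=[
            ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant."),
            ChatMessage(role=MessageRole.USER, content="What does the Kognie class do?"),
        ]
    )
    print(chat_response.message.content)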