| import { LLM, BaseLLMParams } from 'langchain/llms/base' |
|
|
/**
 * Construction parameters for the {@link Cohere} LLM wrapper.
 * All fields are optional here, but the class requires an API key at
 * construction time (see the constructor, which throws when none is given).
 */
export interface CohereInput extends BaseLLMParams {
    /** Sampling temperature forwarded to Cohere's generate call. Class default: 0. */
    temperature?: number

    /** Maximum number of tokens to generate (`max_tokens` on the API). Class default: 250. */
    maxTokens?: number

    /** Model name forwarded to `cohere.generate`. NOTE(review): when omitted, presumably the Cohere API's default model applies — confirm against the SDK. */
    model?: string

    /** Cohere API key used via `cohere.init`. Required in practice — the constructor throws if it is missing. */
    apiKey?: string
}
|
|
| export class Cohere extends LLM implements CohereInput { |
| temperature = 0 |
|
|
| maxTokens = 250 |
|
|
| model: string |
|
|
| apiKey: string |
|
|
| constructor(fields?: CohereInput) { |
| super(fields ?? {}) |
|
|
| const apiKey = fields?.apiKey ?? undefined |
|
|
| if (!apiKey) { |
| throw new Error('Please set the COHERE_API_KEY environment variable or pass it to the constructor as the apiKey field.') |
| } |
|
|
| this.apiKey = apiKey |
| this.maxTokens = fields?.maxTokens ?? this.maxTokens |
| this.temperature = fields?.temperature ?? this.temperature |
| this.model = fields?.model ?? this.model |
| } |
|
|
| _llmType() { |
| return 'cohere' |
| } |
|
|
| |
| async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> { |
| const { cohere } = await Cohere.imports() |
|
|
| cohere.init(this.apiKey) |
|
|
| |
| const generateResponse = await this.caller.callWithOptions({ signal: options.signal }, cohere.generate.bind(cohere), { |
| prompt, |
| model: this.model, |
| max_tokens: this.maxTokens, |
| temperature: this.temperature, |
| end_sequences: options.stop |
| }) |
| try { |
| return generateResponse.body.generations[0].text |
| } catch { |
| throw new Error('Could not parse response.') |
| } |
| } |
|
|
| |
| static async imports(): Promise<{ |
| cohere: typeof import('cohere-ai') |
| }> { |
| try { |
| const { default: cohere } = await import('cohere-ai') |
| return { cohere } |
| } catch (e) { |
| throw new Error('Please install cohere-ai as a dependency with, e.g. `yarn add cohere-ai`') |
| } |
| } |
| } |
|
|