| | package langchain |
| |
|
| | import ( |
| | "context" |
| | "fmt" |
| |
|
| | "github.com/tmc/langchaingo/llms" |
| | "github.com/tmc/langchaingo/llms/huggingface" |
| | ) |
| |
|
// HuggingFace bundles the settings needed to talk to a Hugging Face hosted
// model: modelPath is the repository ID handed to NewHuggingFace, and token
// is the API token used to authenticate the client.
type HuggingFace struct {
	modelPath, token string
}
| |
|
| | func NewHuggingFace(repoId, token string) (*HuggingFace, error) { |
| | if token == "" { |
| | return nil, fmt.Errorf("no huggingface token provided") |
| | } |
| | return &HuggingFace{ |
| | modelPath: repoId, |
| | token: token, |
| | }, nil |
| | } |
| |
|
| | func (s *HuggingFace) PredictHuggingFace(text string, opts ...PredictOption) (*Predict, error) { |
| | po := NewPredictOptions(opts...) |
| |
|
| | |
| | llm, err := huggingface.New(huggingface.WithToken(s.token)) |
| | if err != nil { |
| | return nil, err |
| | } |
| |
|
| | |
| | co := []llms.CallOption{ |
| | llms.WithModel(po.Model), |
| | llms.WithMaxTokens(po.MaxTokens), |
| | llms.WithTemperature(po.Temperature), |
| | llms.WithStopWords(po.StopWords), |
| | } |
| |
|
| | |
| | ctx := context.Background() |
| | completion, err := llm.Call(ctx, text, co...) |
| | if err != nil { |
| | return nil, err |
| | } |
| |
|
| | return &Predict{ |
| | Completion: completion, |
| | }, nil |
| | } |
| |
|