package openai

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"

	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"
)

// ComputeChoices runs model inference for an OpenAI request and returns the
// generated choices together with the accumulated token usage.
//
// It extracts media attachments (images/videos/audios) from the request
// messages, serializes tools/tool_choice, maps the OpenAI logprobs fields to
// backend parameters, and then invokes the inference function req.N times
// (defaulting to 1). For each prediction, cb is called with the finetuned
// response text so the caller can append a schema.Choice to result; any
// per-prediction logprobs are attached to the last appended choice.
// tokenCallback, when non-nil, is forwarded to the backend for streaming.
//
// Returns the collected choices, the summed token usage, and an error if
// inference setup or any individual prediction fails.
func ComputeChoices(
	req *schema.OpenAIRequest,
	predInput string,
	config *config.ModelConfig,
	bcl *config.ModelConfigLoader,
	o *config.ApplicationConfig,
	loader *model.ModelLoader,
	cb func(string, *[]schema.Choice),
	tokenCallback func(string, backend.TokenUsage) bool) ([]schema.Choice, backend.TokenUsage, error) {
	n := req.N // number of completions to return
	result := []schema.Choice{}

	if n == 0 {
		n = 1
	}

	// Collect media attachments from all messages in a single pass.
	images := []string{}
	videos := []string{}
	audios := []string{}
	for _, m := range req.Messages {
		images = append(images, m.StringImages...)
		videos = append(videos, m.StringVideos...)
		audios = append(audios, m.StringAudios...)
	}

	// Serialize tools and tool_choice to JSON strings. Serialization is
	// best-effort: on a marshal failure the field is passed to the backend
	// as an empty string rather than aborting the request.
	toolsJSON := ""
	if len(req.Tools) > 0 {
		if toolsBytes, err := json.Marshal(req.Tools); err == nil {
			toolsJSON = string(toolsBytes)
		}
	}
	toolChoiceJSON := ""
	if req.ToolsChoice != nil {
		if toolChoiceBytes, err := json.Marshal(req.ToolsChoice); err == nil {
			toolChoiceJSON = string(toolChoiceBytes)
		}
	}

	// Extract logprobs from the request.
	// According to the OpenAI API: logprobs is a boolean, top_logprobs (0-20)
	// controls how many top tokens are returned per position.
	var logprobs *int
	var topLogprobs *int
	if req.Logprobs.IsEnabled() {
		if req.TopLogprobs != nil {
			topLogprobs = req.TopLogprobs
			// For backend compatibility, set logprobs to the top_logprobs value.
			logprobs = req.TopLogprobs
		} else {
			// Default to 1 when logprobs is true but top_logprobs is not specified.
			val := 1
			logprobs = &val
			topLogprobs = &val
		}
	}

	// Extract logit_bias from the request.
	// According to the OpenAI API: logit_bias maps token IDs (as strings) to
	// bias values in [-100, 100]. Keep nil when absent so the backend can
	// distinguish "not set" from "empty".
	var logitBias map[string]float64
	if len(req.LogitBias) > 0 {
		logitBias = req.LogitBias
	}

	// Get the model function to call for the result.
	predFunc, err := backend.ModelInference(
		req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
	if err != nil {
		return result, backend.TokenUsage{}, err
	}

	tokenUsage := backend.TokenUsage{}

	for i := 0; i < n; i++ {
		prediction, err := predFunc()
		if err != nil {
			return result, backend.TokenUsage{}, err
		}

		// Accumulate usage across all n completions.
		tokenUsage.Prompt += prediction.Usage.Prompt
		tokenUsage.Completion += prediction.Usage.Completion
		tokenUsage.TimingPromptProcessing += prediction.Usage.TimingPromptProcessing
		tokenUsage.TimingTokenGeneration += prediction.Usage.TimingTokenGeneration

		finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
		cb(finetunedResponse, &result)

		// Attach logprobs to the choice the callback just appended, if any.
		if prediction.Logprobs != nil && len(result) > 0 {
			result[len(result)-1].Logprobs = prediction.Logprobs
		}
	}
	// All predictions succeeded; return nil explicitly rather than the outer
	// err (which was shadowed inside the loop and is always nil here).
	return result, tokenUsage, nil
}