Spaces:
Sleeping
Sleeping
Julian Bilcke
commited on
Commit
·
d64b893
1
Parent(s):
a2c0551
working on improvements
Browse files
src/app/main.tsx
CHANGED
|
@@ -68,7 +68,7 @@ export default function Main() {
|
|
| 68 |
const newPanelsPrompts: string[] = []
|
| 69 |
const newCaptions: string[] = []
|
| 70 |
|
| 71 |
-
const nbPanelsToGenerate =
|
| 72 |
|
| 73 |
for (
|
| 74 |
let currentPanel = 0;
|
|
|
|
| 68 |
const newPanelsPrompts: string[] = []
|
| 69 |
const newCaptions: string[] = []
|
| 70 |
|
| 71 |
+
const nbPanelsToGenerate = 1
|
| 72 |
|
| 73 |
for (
|
| 74 |
let currentPanel = 0;
|
src/app/queries/getStoryContinuation.ts
CHANGED
|
@@ -7,8 +7,8 @@ export const getStoryContinuation = async ({
|
|
| 7 |
preset,
|
| 8 |
stylePrompt = "",
|
| 9 |
userStoryPrompt = "",
|
| 10 |
-
nbPanelsToGenerate =
|
| 11 |
-
nbTotalPanels =
|
| 12 |
existingPanels = [],
|
| 13 |
}: {
|
| 14 |
preset: Preset;
|
|
|
|
| 7 |
preset,
|
| 8 |
stylePrompt = "",
|
| 9 |
userStoryPrompt = "",
|
| 10 |
+
nbPanelsToGenerate = 1,
|
| 11 |
+
nbTotalPanels = 4,
|
| 12 |
existingPanels = [],
|
| 13 |
}: {
|
| 14 |
preset: Preset;
|
src/app/queries/predictNextPanels.ts
CHANGED
|
@@ -6,12 +6,13 @@ import { cleanJson } from "@/lib/cleanJson"
|
|
| 6 |
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
| 7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
| 8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
|
|
|
| 9 |
|
| 10 |
export const predictNextPanels = async ({
|
| 11 |
preset,
|
| 12 |
prompt = "",
|
| 13 |
-
nbPanelsToGenerate =
|
| 14 |
-
nbTotalPanels =
|
| 15 |
existingPanels = [],
|
| 16 |
}: {
|
| 17 |
preset: Preset;
|
|
@@ -58,17 +59,26 @@ export const predictNextPanels = async ({
|
|
| 58 |
|
| 59 |
let result = ""
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
try {
|
| 62 |
// console.log(`calling predict(${query}, ${nbTotalPanels})`)
|
| 63 |
-
result = `${await predict(query,
|
| 64 |
console.log("LLM result (1st trial):", result)
|
| 65 |
if (!result.length) {
|
| 66 |
throw new Error("empty result on 1st trial!")
|
| 67 |
}
|
| 68 |
} catch (err) {
|
| 69 |
// console.log(`prediction of the story failed, trying again..`)
|
|
|
|
|
|
|
|
|
|
| 70 |
try {
|
| 71 |
-
result = `${await predict(query + " \n ",
|
| 72 |
console.log("LLM result (2nd trial):", result)
|
| 73 |
if (!result.length) {
|
| 74 |
throw new Error("empty result on 2nd trial!")
|
|
|
|
| 6 |
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
| 7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
| 8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
| 9 |
+
import { sleep } from "@/lib/sleep"
|
| 10 |
|
| 11 |
export const predictNextPanels = async ({
|
| 12 |
preset,
|
| 13 |
prompt = "",
|
| 14 |
+
nbPanelsToGenerate = 1,
|
| 15 |
+
nbTotalPanels = 4,
|
| 16 |
existingPanels = [],
|
| 17 |
}: {
|
| 18 |
preset: Preset;
|
|
|
|
| 59 |
|
| 60 |
let result = ""
|
| 61 |
|
| 62 |
+
// we don't require a lot of token for our task
|
| 63 |
+
// but to be safe, let's count ~130 tokens per panel
|
| 64 |
+
const nbTokensPerPanel = 130
|
| 65 |
+
|
| 66 |
+
const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
|
| 67 |
+
|
| 68 |
try {
|
| 69 |
// console.log(`calling predict(${query}, ${nbTotalPanels})`)
|
| 70 |
+
result = `${await predict(query, nbMaxNewTokens)}`.trim()
|
| 71 |
console.log("LLM result (1st trial):", result)
|
| 72 |
if (!result.length) {
|
| 73 |
throw new Error("empty result on 1st trial!")
|
| 74 |
}
|
| 75 |
} catch (err) {
|
| 76 |
// console.log(`prediction of the story failed, trying again..`)
|
| 77 |
+
// this should help throttle things on a bit on the LLM API side
|
| 78 |
+
await sleep(2000)
|
| 79 |
+
|
| 80 |
try {
|
| 81 |
+
result = `${await predict(query + " \n ", nbMaxNewTokens)}`.trim()
|
| 82 |
console.log("LLM result (2nd trial):", result)
|
| 83 |
if (!result.length) {
|
| 84 |
throw new Error("empty result on 2nd trial!")
|
src/app/queries/predictWithGroq.ts
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
import Groq from "groq-sdk"
|
| 4 |
|
| 5 |
-
export async function predict(inputs: string,
|
| 6 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
| 7 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
| 8 |
|
|
@@ -18,6 +18,9 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
| 18 |
const res = await groq.chat.completions.create({
|
| 19 |
messages: messages,
|
| 20 |
model: groqApiModel,
|
|
|
|
|
|
|
|
|
|
| 21 |
})
|
| 22 |
|
| 23 |
return res.choices[0].message.content || ""
|
|
|
|
| 2 |
|
| 3 |
import Groq from "groq-sdk"
|
| 4 |
|
| 5 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
| 6 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
| 7 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
| 8 |
|
|
|
|
| 18 |
const res = await groq.chat.completions.create({
|
| 19 |
messages: messages,
|
| 20 |
model: groqApiModel,
|
| 21 |
+
stream: false,
|
| 22 |
+
temperature: 0.5,
|
| 23 |
+
max_tokens: nbMaxNewTokens,
|
| 24 |
})
|
| 25 |
|
| 26 |
return res.choices[0].message.content || ""
|
src/app/queries/predictWithHuggingFace.ts
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
| 4 |
import { LLMEngine } from "@/types"
|
| 5 |
|
| 6 |
-
export async function predict(inputs: string,
|
| 7 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
| 8 |
|
| 9 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
|
@@ -12,10 +12,6 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
| 12 |
|
| 13 |
let hfie: HfInferenceEndpoint = hf
|
| 14 |
|
| 15 |
-
// we don't require a lot of token for our task
|
| 16 |
-
// but to be safe, let's count ~110 tokens per panel
|
| 17 |
-
const nbMaxNewTokens = nbPanels * 130 // 110 isn't enough anymore for long dialogues
|
| 18 |
-
|
| 19 |
switch (llmEngine) {
|
| 20 |
case "INFERENCE_ENDPOINT":
|
| 21 |
if (inferenceEndpoint) {
|
|
|
|
| 3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
| 4 |
import { LLMEngine } from "@/types"
|
| 5 |
|
| 6 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
| 7 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
| 8 |
|
| 9 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
|
|
|
| 12 |
|
| 13 |
let hfie: HfInferenceEndpoint = hf
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
switch (llmEngine) {
|
| 16 |
case "INFERENCE_ENDPOINT":
|
| 17 |
if (inferenceEndpoint) {
|
src/app/queries/predictWithOpenAI.ts
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
import type { ChatCompletionMessage } from "openai/resources/chat"
|
| 4 |
import OpenAI from "openai"
|
| 5 |
|
| 6 |
-
export async function predict(inputs: string,
|
| 7 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
| 8 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
| 9 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
|
@@ -23,6 +23,8 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
| 23 |
stream: false,
|
| 24 |
model: openaiApiModel,
|
| 25 |
temperature: 0.8,
|
|
|
|
|
|
|
| 26 |
// TODO: use the nbPanels to define a max token limit
|
| 27 |
})
|
| 28 |
|
|
|
|
| 3 |
import type { ChatCompletionMessage } from "openai/resources/chat"
|
| 4 |
import OpenAI from "openai"
|
| 5 |
|
| 6 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
| 7 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
| 8 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
| 9 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
|
|
|
| 23 |
stream: false,
|
| 24 |
model: openaiApiModel,
|
| 25 |
temperature: 0.8,
|
| 26 |
+
max_tokens: nbMaxNewTokens,
|
| 27 |
+
|
| 28 |
// TODO: use the nbPanels to define a max token limit
|
| 29 |
})
|
| 30 |
|