import uvicorn
import re
import asyncio
from fastapi import FastAPI, Query
from pydantic import BaseModel
from huggingface_hub import InferenceClient
from typing import List


app = FastAPI()
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

SYSTEM_PROMPT = "You are a very powerful AI that generates interesting stories for short-form content consumption. Make sure to hook the reader's attention in the first few seconds. Make sure to be engaging and creative in your responses."


class Item(BaseModel):
    prompt: str
    history: List[str] = []  # flat list of alternating user / assistant turns
    #  system_prompt: str = "You are a very powerful AI assistant."
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0


def format_prompt(message, history):
    # Build a Mixtral-Instruct prompt; `history` is a flat list of
    # alternating user / assistant turns, so pair consecutive entries.
    prompt = "<s>"
    for user_prompt, bot_response in zip(history[::2], history[1::2]):
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
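

# For example (a sketch with a hypothetical one-exchange history):
#   format_prompt("Continue the story", ["Tell me a story", "Once upon a time..."])
# yields:
#   "<s>[INST] Tell me a story [/INST] Once upon a time...</s> [INST] Continue the story [/INST]"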


def generate(item: Item):
    temperature = max(float(item.temperature), 1e-2)

    formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}\n\n{item.prompt}", item.history)
    stream = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=float(item.top_p),
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=42,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = "".join(response.token.text for response in stream)
    # Strip leftover template markup: special/HTML-like tokens (e.g., <s>, </s>),
    # [INST]/[/INST] markers, and extra whitespace.
    output = re.sub(r"<[^>]+>", "", output)
    output = re.sub(r"\[/?INST\]", "", output)
    output = re.sub(r"\s+", " ", output).strip()

    return output


@app.get("/generate/")
async def generate_text(
    prompt: str,
    history: List[str] = Query(default=[]),  # alternating user / assistant turns
    # system_prompt: str = "You are a very powerful AI assistant.",
    temperature: float = 0.0,
    max_new_tokens: int = 1048,
    top_p: float = 0.15,
    repetition_penalty: float = 1.0,
):
    item = Item(
        prompt=prompt,
        history=history,
        # system_prompt=system_prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    response = await asyncio.to_thread(generate, item)

    return {"response": response}
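

# A minimal entry point sketch so the module can be run directly using the
# uvicorn import above; the host and port are assumptions, not part of the
# original deployment configuration.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (hypothetical values):
#   GET http://localhost:8000/generate/?prompt=Tell%20me%20a%20story&temperature=0.7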