from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, set_seed


app = FastAPI()

# Load the model and tokenizer once at startup so every request reuses them.
# Note: this is the base model; the [INST] chat template used below is the
# Mistral-Instruct convention.
model_id = "mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
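
# BitsAndBytesConfig is imported but unused; a minimal sketch of how 4-bit
# quantized loading could look (an assumption, not the original configuration):
#
# quant_config = BitsAndBytesConfig(load_in_4bit=True)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, device_map="auto", quantization_config=quant_config
# )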

#client = InferenceClient("HFHAB/FinetunedMistralModel")

class Item(BaseModel):
    prompt: str
    history: list  # list of (user_prompt, bot_response) pairs
    system_prompt: str
    temperature: float = 0.3
    max_new_tokens: int = 5000
    top_p: float = 0.15
    repetition_penalty: float = 1.0
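
# An illustrative request body for this schema (example values, not from the source):
# {"prompt": "What is FastAPI?", "history": [], "system_prompt": "You are a helpful assistant."}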

def format_prompt(message, history):
    # Build a Mistral-Instruct style prompt: each past turn is closed with </s>
    # before the new instruction is appended.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
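
# For example, format_prompt("Hi", [("Hello", "Hey!")]) produces:
# "<s>[INST] Hello [/INST] Hey!</s> [INST] Hi [/INST]"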

def formatting_func(example):
    # Fine-tuning data formatter; not used by the serving path.
    text = f"### Question: {example['input']}\n ### Answer: {example['output']}"
    return text

def generate(item: Item):
    # Clamp temperature: sampling breaks down at values very close to zero.
    temperature = float(item.temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(item.top_p)

    set_seed(42)  # fixed seed so repeated requests are reproducible

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=top_p,
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,  # Mistral defines no pad token
    )

    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, **generate_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    output = tokenizer.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
    # Alternative path: stream from a hosted Inference Endpoint instead of generating locally.
    #stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    #output = ""
    #for response in stream:
    #    output += response.token.text
    return output

@app.post("/generate/")
def generate_text(item: Item):
    # Plain (non-async) def so FastAPI runs the blocking generation in a worker thread.
    return {"response": generate(item)}