import asyncio
import re
from typing import Annotated, List

import uvicorn
from fastapi import FastAPI, Query
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# ASGI application object; the route handler in this module registers on it.
app = FastAPI()
# Hugging Face inference client bound to the Mixtral instruct model. It is
# created once at import time and reused by every call to generate().
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Instruction text that generate() places in front of every user prompt.
SYSTEM_PROMPT = "You are a very powerful AI to generate interesting stories for short-form content consumption. Make sure to hook the readers attention in the first few seconds. Make sure to be engaging and creative in your responses."
# Parameters for one text-generation request, consumed by generate().
#
# NOTE: a per-request ``system_prompt`` field is currently disabled; the
# module-level SYSTEM_PROMPT is used for every request instead.
class Item(BaseModel):
    # --- conversation content -------------------------------------------
    # The user's request; generate() combines it with SYSTEM_PROMPT.
    prompt: str
    # Earlier conversation turns, forwarded to format_prompt().
    history: List[str] = []

    # --- sampling controls (passed to client.text_generation) -----------
    # generate() raises this to at least 0.01 before sending it.
    temperature: float = 0.0
    # Upper bound on the number of tokens the model may produce.
    max_new_tokens: int = 1048
    # Nucleus-sampling threshold.
    top_p: float = 0.15
    # Penalty applied to repeated tokens; 1.0 is forwarded unchanged.
    repetition_penalty: float = 1.0
def format_prompt(message, history):
    """Build an ``[INST] ... [/INST]`` prompt from prior turns and a new message.

    Args:
        message: The new user message; it is always emitted last.
        history: Prior conversation turns, in either of two shapes:
            * an iterable of ``(user_prompt, bot_response)`` pairs, or
            * a flat list of strings alternating user, bot, user, bot, ...
              (the shape implied by the ``List[str]`` annotations used by
              the rest of this module).
            ``None`` or an empty iterable means "no history".

    Returns:
        The concatenated prompt string.

    Raises:
        ValueError: If a flat string history has an odd number of entries,
            or an entry in pair form cannot be unpacked into two values.
    """
    turns = list(history or ())

    # A list made only of strings cannot be pair form: unpacking a string
    # splits it into characters. Treat it as alternating user/bot turns.
    if turns and all(isinstance(turn, str) for turn in turns):
        if len(turns) % 2:
            raise ValueError(
                "flat history must alternate user and bot messages "
                f"(got {len(turns)} entries)"
            )
        turns = list(zip(turns[::2], turns[1::2]))

    parts = [
        f"[INST] {user_prompt} [/INST] {bot_response} "
        for user_prompt, bot_response in turns
    ]
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)
def generate(item: Item):
    """Run one streamed generation request and return the cleaned-up text.

    The user prompt is joined to SYSTEM_PROMPT, formatted together with
    ``item.history`` by format_prompt(), and sent to the shared ``client``.
    The streamed token texts are concatenated, then post-processed:

    * every ``<...>`` span is removed (this also drops any angle-bracketed
      text the model produced on purpose);
    * every run of whitespace, newlines included, becomes a single space,
      and the result is stripped.

    Args:
        item: Prompt, history and sampling settings for this request.

    Returns:
        The generated text as a single-line string.
    """
    sampling_options = {
        # Clamped to at least 0.01 (presumably because sampling with a zero
        # temperature is rejected by the backend — TODO confirm).
        "temperature": max(float(item.temperature), 1e-2),
        "max_new_tokens": item.max_new_tokens,
        "top_p": float(item.top_p),
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,
        "stream": True,
        "details": True,
        "return_full_text": False,
    }
    full_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)

    # With stream=True and details=True each streamed item exposes the
    # generated piece as ``.token.text``.
    pieces = [
        chunk.token.text
        for chunk in client.text_generation(full_prompt, **sampling_options)
    ]
    raw_text = "".join(pieces)

    without_tags = re.sub(r"<[^>]+>", "", raw_text)
    return re.sub(r"\s+", " ", without_tags).strip()
@app.get("/generate/")
async def generate_text(
    prompt: str,
    # A bare ``List[str]`` on a GET route is read by FastAPI as a request
    # body; the explicit Query() makes it a repeatable query parameter
    # (``?history=a&history=b``). FastAPI copies the default per request.
    history: Annotated[List[str], Query()] = [],
    temperature: float = 0.0,
    max_new_tokens: int = 1048,
    top_p: float = 0.15,
    repetition_penalty: float = 1.0,
):
    """Generate text for ``prompt`` and return it as JSON.

    All arguments are query parameters. They are bundled into an ``Item``
    and handed to the blocking ``generate`` function on a worker thread so
    the event loop stays free while the model streams its answer.

    Args:
        prompt: The user's request.
        history: Earlier conversation turns; repeat the parameter to send
            several values.
        temperature: Sampling temperature.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        ``{"response": <generated text>}``.
    """
    item = Item(
        prompt=prompt,
        history=history,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    response = await asyncio.to_thread(generate, item)
    return {"response": response}