"""FastAPI service that generates short-form stories via the Mixtral-8x7B
Instruct model hosted on the Hugging Face Inference API."""

import asyncio
import re
from typing import List, Optional

import uvicorn  # noqa: F401 — kept so the app can be launched programmatically
from fastapi import FastAPI
from huggingface_hub import InferenceClient
from pydantic import BaseModel

app = FastAPI()

# All generation is delegated to this remote hosted model.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Prepended to every user prompt before formatting (see generate()).
SYSTEM_PROMPT = (
    "You are a very powerful AI to generate interesting stories for short-form "
    "content consumption. Make sure to hook the readers attention in the first "
    "few seconds. Make sure to be engaging and creative in your responses."
)


class Item(BaseModel):
    """Request parameters for a single story-generation call.

    NOTE(review): `history` is annotated ``List[str]``, but ``format_prompt``
    unpacks each entry into a ``(user_prompt, bot_response)`` pair — a plain
    string entry would unpack character-wise (or raise). Callers appear to be
    expected to pass two-element pairs; confirm and tighten the annotation.
    """

    prompt: str
    history: List[str] = []  # pydantic deep-copies defaults, so this is safe here
    temperature: float = 0.0  # clamped to >= 1e-2 before sampling (see generate())
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0


def format_prompt(message, history):
    """Render a conversation in the Mixtral ``[INST] ... [/INST]`` chat template.

    ``history`` is an iterable of (user_prompt, bot_response) pairs; ``message``
    is the new user turn, appended last with no trailing response.
    """
    parts = []
    for user_prompt, bot_response in history:
        parts.append(f"[INST] {user_prompt} [/INST]")
        parts.append(f" {bot_response} ")
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)


def generate(item: Item) -> str:
    """Run blocking text generation for ``item`` and post-process the result.

    Streams tokens from the inference endpoint, concatenates them, strips
    HTML-like tags, and collapses whitespace. Blocking — callers should run
    it off the event loop (see ``generate_text``).
    """
    # do_sample=True requires a strictly positive temperature.
    temperature = max(float(item.temperature), 1e-2)

    formatted_prompt = format_prompt(f"{SYSTEM_PROMPT}, {item.prompt}", item.history)
    stream = client.text_generation(
        formatted_prompt,
        temperature=temperature,
        max_new_tokens=item.max_new_tokens,
        top_p=float(item.top_p),
        repetition_penalty=item.repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: identical requests reproduce the same story
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = "".join(response.token.text for response in stream)

    # Remove unwanted sequences or patterns (e.g., [/INST]-style tags).
    output = re.sub(r"<[^>]+>", "", output)  # Remove any HTML-like tags
    output = re.sub(r"\s+", " ", output).strip()  # Clean up extra whitespace
    return output


@app.get("/generate/")
async def generate_text(
    prompt: str,
    history: Optional[List[str]] = None,  # None sentinel instead of a mutable [] default
    temperature: float = 0.0,
    max_new_tokens: int = 1048,
    top_p: float = 0.15,
    repetition_penalty: float = 1.0,
):
    """Generate a story for ``prompt``.

    The blocking ``generate`` call is offloaded to a worker thread so the
    event loop stays responsive. Returns ``{"response": <generated text>}``.
    """
    item = Item(
        prompt=prompt,
        history=history if history is not None else [],
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    response = await asyncio.to_thread(generate, item)
    return {"response": response}