# God_HP / app.py
# Author: jasperBHOS — commit 417b6c4 ("Upload app.py")
from flask import Flask, request, jsonify
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
import textwrap
import torch
import re
app = Flask(__name__)

# NOTE(review): `tokenizer` and `model` are used by generate_story() but are
# never defined in this file. They must be created (e.g. with a transformers
# tokenizer/model loader — TODO confirm which) before /generate is served;
# otherwise every request fails with a NameError.
# tokenizer = ...
# model = load_model("/path/to/your/model.pth")

# NOTE(review): not referenced anywhere in the visible code (the `textwrap`
# import is unused too); presumably meant for wrapping printed output —
# confirm before removing.
max_print_width = 150
class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation once the most recent output degenerates into a loop.

    After every generation step the newest token of batch item 0 is decoded
    and appended to a rolling text buffer. Once at least
    ``window_size * threshold`` characters have been produced, the tail of
    that length is searched for any chunk of 1..``window_size`` characters
    (newlines included) repeated ``threshold`` or more times back to back.
    A hit stops generation.

    Args:
        tokenizer: object with a ``decode`` method, applied to one-token slices.
        window_size: longest repeated chunk, in characters, that is detected.
        threshold: number of consecutive repetitions that triggers a stop.
            It must be >= 2, because a value of 1 would match any text.

    Raises:
        ValueError: if ``window_size`` < 1 or ``threshold`` < 2.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if threshold < 2:
            raise ValueError(f"threshold must be >= 2, got {threshold}")
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.threshold = threshold
        # Rolling buffer of decoded output. Only its last `_span` characters
        # are ever inspected, so it is trimmed to that length.
        self.generated_text = ""
        self._span = window_size * threshold
        # Compile once instead of rebuilding `window_size` patterns per token.
        # DOTALL lets "." match "\n", so looping lines (e.g. "The end.\n" x3)
        # and runs of blank lines are caught as well.
        self._patterns = [
            re.compile(r"(.{%d})\1{%d,}" % (length, threshold - 1), re.DOTALL)
            for length in range(1, window_size + 1)
        ]

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the recent output of batch item 0 is repetitive."""
        # NOTE(review): decoding one token at a time may drop a leading space
        # that some tokenizers fold into the token itself. Confirm this for
        # the tokenizer actually in use.
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])
        if len(self.generated_text) < self._span:
            return False
        # Bound memory: nothing older than the inspected tail is ever needed.
        self.generated_text = self.generated_text[-self._span:]
        return any(pattern.search(self.generated_text) for pattern in self._patterns)
def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128):
    """Generate a story continuation for *prompt* with the global model.

    The prompt is prefixed with ``[Tone: ...] [Genre: ...]`` tags. Generation
    runs in a background thread, and text is collected from a
    ``TextIteratorStreamer`` until generation ends: naturally, after
    ``max_tokens`` new tokens, or when RepetitionStoppingCriteria detects a
    loop.

    Args:
        prompt: user text the story continues from.
        tone: value placed in the ``[Tone: ...]`` tag.
        genre: value placed in the ``[Genre: ...]`` tag.
        max_tokens: maximum number of new tokens to generate.

    Returns:
        The generated text only. The tagged prompt and special tokens are not
        echoed back.

    Raises:
        Any exception raised by ``model.generate``. It is re-raised here
        instead of leaving the caller blocked on the streamer forever.
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"
    # Follow the model's device rather than hard-coding "cuda".
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(model.device)
    # Without these flags the streamer echoes the tagged prompt and
    # special-token markers back into the returned story.
    text_streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    stopping_criteria = StoppingCriteriaList([RepetitionStoppingCriteria(tokenizer)])
    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    errors = []

    def _run_generation():
        # Runs in the worker thread. If generate() fails, the streamer never
        # receives its end-of-stream signal and iteration below would block
        # forever, so record the error and close the stream explicitly.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            text_streamer.end()

    thread = Thread(target=_run_generation, daemon=True)
    thread.start()
    pieces = list(text_streamer)
    thread.join()
    if errors:
        raise errors[0]
    return "".join(pieces)
@app.route("/generate", methods=["POST"])
def generate():
    """Handle ``POST /generate``: build a story from a JSON request body.

    Expected body is a JSON object; all keys are optional:
        prompt (default ""), tone (default "dark"), genre (default "fantasy"),
        max_tokens (integer, default 128).

    Returns:
        200 with ``{"story": <text>}`` on success, or 400 with
        ``{"error": <message>}`` when the body is not a JSON object or
        ``max_tokens`` is not an integer in the allowed range.
    """
    # Safety cap so a single request cannot occupy the model indefinitely.
    # NOTE(review): tune to the deployment.
    max_tokens_limit = 2048

    # silent=True returns None for a missing or invalid JSON body, so it gets
    # the same 400 below instead of an unhandled error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    max_tokens = data.get("max_tokens", 128)
    # bool is a subclass of int, so reject it explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or not 1 <= max_tokens <= max_tokens_limit
    ):
        message = f"'max_tokens' must be an integer between 1 and {max_tokens_limit}"
        return jsonify({"error": message}), 400

    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})
if __name__ == "__main__":
    # When executed directly, start Flask's built-in server on every network
    # interface (0.0.0.0) at port 5000.
    server_options = {"host": "0.0.0.0", "port": 5000}
    app.run(**server_options)