"""Flask service that generates short stories with a causal language model.

Exposes ``POST /generate``. The handler builds a tone/genre-tagged prompt,
runs ``model.generate`` on a worker thread while collecting the streamed
text, and returns the collected text as JSON.

The module expects ``tokenizer`` and ``model`` to exist as module-level
globals; they are not created here.
"""

import re
import textwrap
from functools import lru_cache
from threading import Thread

import torch
from flask import Flask, jsonify, request
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

app = Flask(__name__)

# Assuming tokenizer and model are loaded
# tokenizer = ...
# model = load_model("/path/to/your/model.pth")

max_print_width = 150


@lru_cache(maxsize=32)
def _repetition_regex(window_size, threshold):
    """Return a compiled regex matching a repeated run of characters.

    The pattern matches any unit of 1..``window_size`` characters (newlines
    excluded, since ``.`` is used without DOTALL) that is immediately
    followed by at least ``threshold - 1`` further copies of itself.
    """
    return re.compile(rf"(.{{1,{window_size}}})\1{{{threshold - 1},}}")


class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation when the most recent text repeats a short unit.

    On every call the last token id of the first sequence is decoded and
    appended to ``generated_text``. Once at least
    ``window_size * threshold`` characters have accumulated, the trailing
    ``window_size * threshold`` characters are searched for any unit of
    1..``window_size`` characters repeated ``threshold`` or more times in a
    row.

    Args:
        tokenizer: Object with a ``decode`` method for token ids.
        window_size: Longest repeated unit, in characters, to look for.
        threshold: Number of consecutive copies that triggers a stop.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3):
        super().__init__()
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.threshold = threshold
        self.generated_text = ""

    def __call__(self, input_ids, scores, **kwargs):
        """Return True if generation should stop because of repetition.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted for interface compatibility and ignored.
        """
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])

        span = self.window_size * self.threshold
        # A non-positive window has no candidate unit lengths, so it can
        # never report repetition.
        if self.window_size < 1 or len(self.generated_text) < span:
            return False

        recent_text = self.generated_text[-span:]
        # One pattern with a bounded quantifier covers every unit length the
        # per-length loop would try; the compiled regex is cached.
        pattern = _repetition_regex(self.window_size, self.threshold)
        return pattern.search(recent_text) is not None


def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128):
    """Generate story text for ``prompt`` with the module-level model.

    Args:
        prompt: User-supplied story prompt.
        tone: Tone tag inserted into the prompt header.
        genre: Genre tag inserted into the prompt header.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        All text emitted by the streamer, concatenated.

    Raises:
        RuntimeError: If ``model.generate`` fails on the worker thread; the
            original exception is attached as ``__cause__``.
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"

    # Follow the model's own device when it exposes one; fall back to the
    # previously hard-coded "cuda" otherwise.
    device = getattr(model, "device", "cuda")
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(device)

    # NOTE(review): the streamer is created without skip_prompt, so the
    # returned text presumably starts with the tagged prompt. Confirm whether
    # API clients rely on that before changing it.
    text_streamer = TextIteratorStreamer(tokenizer)
    stopping_criteria = StoppingCriteriaList(
        [RepetitionStoppingCriteria(tokenizer)]
    )

    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    failures = []

    def _run_generation():
        # Runs on the worker thread. If generation fails the streamer never
        # receives its end-of-stream signal, which would leave the consumer
        # below blocked forever, so record the error and end the stream.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the caller
            failures.append(exc)
            text_streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    full_text = "".join(text_streamer)
    worker.join()

    if failures:
        raise RuntimeError("Text generation failed") from failures[0]
    return full_text


@app.route("/generate", methods=["POST"])
def generate():
    """Handle ``POST /generate``.

    Expects a JSON object with optional keys ``prompt`` (default ``""``),
    ``tone`` (default ``"dark"``), ``genre`` (default ``"fantasy"``) and
    ``max_tokens`` (default ``128``).

    Returns:
        ``{"story": <text>}`` on success, or ``{"error": <message>}`` with
        status 400 when the body is not a JSON object or a field has the
        wrong type.
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so every malformed request gets the same JSON error response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    max_tokens = data.get("max_tokens", 128)

    for field, value in (("prompt", prompt), ("tone", tone), ("genre", genre)):
        if not isinstance(value, str):
            return jsonify({"error": f"'{field}' must be a string"}), 400

    # bool is a subclass of int, so it is rejected explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 1
    ):
        return jsonify({"error": "'max_tokens' must be a positive integer"}), 400

    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)