Spaces:
Runtime error
Runtime error
import os
import re
import textwrap
from threading import Thread

import torch
from flask import Flask, jsonify, request
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
# The Flask application object for this service.
app = Flask(import_name=__name__)

# NOTE(review): generate_story() reads module-level `tokenizer` and `model`,
# but nothing in this file (as shown) creates them. Define both here before
# serving requests, otherwise every generation call fails with NameError.
# Placeholders:
#   tokenizer = ...
#   model = load_model("/path/to/your/model.pth")

# NOTE(review): not referenced by any code visible in this file.
max_print_width = 150
class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation once recent output degenerates into a repeating loop.

    On each call the last token id of the first row of ``input_ids``
    (presumably the newest generated token) is decoded and appended to a
    rolling text buffer. Generation stops when the trailing
    ``window_size * threshold`` characters contain a unit of
    ``min_pattern_len``..``window_size`` characters repeated at least
    ``threshold`` times back to back. Newlines never form part of a repeating
    unit, because the regex ``.`` does not match them.

    With the default ``min_pattern_len=1``, any single character repeated
    ``threshold`` times (e.g. an ellipsis ``"..."`` or ``"!!!"``) counts as a
    loop. Pass a larger ``min_pattern_len`` to ignore such short runs.

    The buffer is per-instance state, so use a fresh instance for each
    generation run.

    Args:
        tokenizer: Object exposing ``decode(token_ids) -> str``.
        window_size: Longest repeating unit, in characters, to look for.
        threshold: Number of back-to-back occurrences that counts as a loop.
        min_pattern_len: Shortest repeating unit, in characters, to look for.
            Values below 1 are treated as 1.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3, min_pattern_len=1):
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.threshold = threshold
        self.min_pattern_len = min_pattern_len
        self.generated_text = ""

    def __call__(self, input_ids, scores, **kwargs):
        """Record the newest token and report whether a repetition loop ended the text.

        Args:
            input_ids: Token ids so far; only the last id of row 0 is read.
            scores: Unused; part of the ``StoppingCriteria`` call signature.

        Returns:
            bool: ``True`` when generation should stop.
        """
        span = self.window_size * self.threshold
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])
        if span > 0:
            # Only the trailing `span` characters are ever searched. Keeping
            # just those bounds memory and avoids re-copying an ever-growing
            # string on every token.
            self.generated_text = self.generated_text[-span:]
        if len(self.generated_text) < span:
            return False

        shortest = max(1, self.min_pattern_len)
        if self.window_size < shortest:
            return False
        recent_text = self.generated_text[-span:]
        # A single {shortest,window_size} length range lets the regex engine
        # try every unit length at each position. This matches exactly when
        # some fixed-length pattern (.{i})\1{threshold-1,} would.
        pattern = r"(.{%d,%d})\1{%d,}" % (shortest, self.window_size, self.threshold - 1)
        return re.search(pattern, recent_text) is not None
def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128):
    """Generate story text for *prompt* using the module-level model.

    The tone and genre are prepended as ``[Tone: ...] [Genre: ...]`` tags.
    ``model.generate`` runs on a background thread, and its output is
    collected through a ``TextIteratorStreamer``. A
    ``RepetitionStoppingCriteria`` cuts off degenerate repetition loops.

    Relies on module-level ``tokenizer`` and ``model`` objects, which are not
    created inside this function.

    Args:
        prompt: User-supplied story opening.
        tone: Inserted into the ``[Tone: ...]`` tag.
        genre: Inserted into the ``[Genre: ...]`` tag.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        str: The concatenated streamer output, with special tokens removed.
        The streamer is not told to skip the prompt, so the tagged prompt is
        included in the returned text.

    Raises:
        Exception: Whatever ``model.generate`` raised. It is re-raised here
            instead of leaving the caller blocked forever on the streamer.
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"
    # Follow the model's device instead of hard-coding "cuda", so CPU-only
    # hosts work too. The behaviour is unchanged when the model is on a GPU.
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(model.device)
    # skip_special_tokens keeps markers such as BOS/EOS out of the story text.
    text_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    repetition_checker = RepetitionStoppingCriteria(tokenizer)
    stopping_criteria = StoppingCriteriaList([repetition_checker])
    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    errors = []

    def _run_generation():
        # Run generate() and surface failures to the consuming thread. If
        # generate() raised, the streamer would never receive its end-of-stream
        # signal, and iterating it below would block forever.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            text_streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    full_text = "".join(text_streamer)
    thread.join()
    if errors:
        raise errors[0]
    return full_text
@app.route("/generate", methods=["POST"])
def generate():
    """Handle a JSON story-generation request.

    Reads a JSON object body with optional keys ``prompt`` (default ``""``),
    ``tone`` (default ``"dark"``), ``genre`` (default ``"fantasy"``) and
    ``max_tokens`` (default ``128``), then responds with ``{"story": <text>}``.

    Without the route decorator this view was never registered on ``app``,
    so the server exposed no endpoint at all.

    Returns:
        A JSON response with the generated story, or a ``(json, 400)`` error
        when the body is not a JSON object or ``max_tokens`` is not a
        positive integer.
    """
    # silent=True makes a missing or malformed JSON body return None instead
    # of raising, so the client gets the explicit 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    max_tokens = data.get("max_tokens", 128)
    # max_tokens is client-controlled and goes straight to generate(). Accept
    # only real positive ints; bool is rejected explicitly because it
    # subclasses int.
    if isinstance(max_tokens, bool) or not isinstance(max_tokens, int) or max_tokens < 1:
        return jsonify({"error": "max_tokens must be a positive integer"}), 400

    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})
if __name__ == "__main__":
    # HOST / PORT environment variables can override the bind address, e.g.
    # when a hosting platform dictates the port. The fallbacks are the
    # previous hard-coded 0.0.0.0:5000.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5000")),
    )