Spaces:
Runtime error
Runtime error
File size: 2,555 Bytes
417b6c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from flask import Flask, request, jsonify
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
import textwrap
import torch
import re
# Flask application object; the /generate route further down is registered on it.
app = Flask(__name__)
# NOTE: `tokenizer` and `model` are read as module-level globals by
# generate_story(), but nothing in the visible part of this file assigns them.
# They must be defined here (or otherwise injected) before the first request,
# or generate_story() will fail with NameError.
# tokenizer = ...
# model = load_model("/path/to/your/model.pth")
# Not referenced anywhere in the visible part of this file.
max_print_width = 150
class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation once the recent output degenerates into a repeating loop.

    Every call decodes the newest token of the first sequence in the batch,
    appends it to ``generated_text`` and scans the last
    ``window_size * threshold`` characters for any unit of 1..``window_size``
    characters that occurs ``threshold`` or more times back to back.

    Args:
        tokenizer: Object whose ``decode`` turns token ids into text.
        window_size: Longest repeating unit, in characters, to look for.
        threshold: Number of consecutive occurrences that counts as a loop.

    Note:
        The patterns are compiled once in ``__init__``; changing
        ``window_size`` or ``threshold`` on an existing instance does not
        change what is detected.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3):
        super().__init__()
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.threshold = threshold
        self.generated_text = ""
        # Only this many trailing characters are ever inspected.
        self._span = window_size * threshold
        # One pattern per unit length, built once instead of on every token.
        # re.DOTALL lets "." match "\n" so loops that span line breaks
        # (repeated lines, runs of blank lines) are detected as well.
        self._patterns = [
            re.compile(
                '(.{' + str(size) + '})\\1{' + str(threshold - 1) + ',}',
                re.DOTALL,
            )
            for size in range(1, window_size + 1)
        ]

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when the tail of the generated text is a repeating loop.

        Only the last token of ``input_ids[0]`` is decoded on each call, so
        the criterion tracks the first sequence of the batch only. ``scores``
        and any extra keyword arguments are ignored.
        """
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])
        # Too little text yet for `threshold` repetitions of the longest unit.
        if len(self.generated_text) < self._span:
            return False
        recent_text = self.generated_text[-self._span:]
        return any(pattern.search(recent_text) for pattern in self._patterns)
def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128):
    """Generate text for *prompt* using the module-level ``model`` and ``tokenizer``.

    The prompt is prefixed with ``[Tone: ...] [Genre: ...]`` tags, generation
    runs in a background thread, and the pieces yielded by the streamer are
    collected into a single string.

    Args:
        prompt: User text to continue.
        tone: Value placed in the ``[Tone: ...]`` tag.
        genre: Value placed in the ``[Genre: ...]`` tag.
        max_tokens: Passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        The concatenation of everything the streamer yields.
        NOTE(review): the streamer is created with default options, so the
        prompt is presumably echoed at the start of the result - confirm
        whether that is intended.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"
    # Follow the model's own device rather than assuming CUDA; keep "cuda" as
    # the fallback for model objects that expose no ``device`` attribute.
    device = getattr(model, "device", "cuda")
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(device)
    text_streamer = TextIteratorStreamer(tokenizer)
    repetition_checker = RepetitionStoppingCriteria(tokenizer)
    stopping_criteria = StoppingCriteriaList([repetition_checker])
    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    failures = []

    def _run_generation():
        """Run generation; on failure record the error and close the stream."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            failures.append(exc)
            # generate() died before signalling end-of-stream; without this
            # the consumer below would block on the streamer forever.
            text_streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    full_text = "".join(text_streamer)
    thread.join()
    if failures:
        raise failures[0]
    return full_text
@app.route("/generate", methods=["POST"])
def generate():
    """Handle POST /generate: build a story from a JSON request body.

    Expected JSON object keys (all optional):
        prompt: Text to continue (default ``""``).
        tone: Tone tag (default ``"dark"``).
        genre: Genre tag (default ``"fantasy"``).
        max_tokens: Positive integer limit on new tokens (default ``128``).

    Returns:
        ``{"story": <text>}`` on success, or ``{"error": <message>}`` with
        status 400 when the body is not a JSON object or ``max_tokens`` is
        not a positive integer.
    """
    data = request.get_json()
    # get_json() can yield None or a non-object JSON value (list, string, ...);
    # calling .get() on those would surface as an unhandled 500.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    max_tokens = data.get("max_tokens", 128)
    # bool is a subclass of int, so reject it explicitly.
    if (
        isinstance(max_tokens, bool)
        or not isinstance(max_tokens, int)
        or max_tokens < 1
    ):
        return jsonify({"error": "'max_tokens' must be a positive integer."}), 400

    prompt = data.get("prompt", "")
    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})
# Script entry point: when run directly, start Flask's built-in server
# listening on all interfaces (0.0.0.0) at port 5000.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
|