File size: 2,555 Bytes
417b6c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from flask import Flask, request, jsonify
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
import textwrap
import torch
import re

# Flask application serving the /generate endpoint below.
app = Flask(__name__)

# Assuming tokenizer and model are loaded
# tokenizer = ...
# model = load_model("/path/to/your/model.pth")

# Width for wrapped console output (pairs with the textwrap import);
# NOTE(review): not referenced in this chunk — presumably used elsewhere in the file.
max_print_width = 150

class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation once the decoded output's tail contains a short
    substring repeated consecutively.

    A unit of 1..window_size characters repeated `threshold` or more times
    in a row (e.g. "hahaha" for threshold=3) triggers a stop.

    NOTE: the checker is stateful across __call__ invocations — create a
    fresh instance for every generation run.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3):
        self.tokenizer = tokenizer
        self.window_size = window_size  # longest repeating unit checked, in characters
        self.threshold = threshold      # minimum consecutive repeats that trigger a stop
        self.generated_text = ""        # rolling tail of the decoded output

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newest token and append it to the rolling buffer.
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])

        window = self.window_size * self.threshold
        # Only the last `window` characters are ever inspected, so trim the
        # buffer — otherwise it grows for the whole generation run.
        if len(self.generated_text) > window:
            self.generated_text = self.generated_text[-window:]

        if len(self.generated_text) >= window:
            for unit_len in range(1, self.window_size + 1):
                # (.{n})\1{t-1,}  ==  some n-char unit repeated >= t times.
                pattern = r'(.{%d})\1{%d,}' % (unit_len, self.threshold - 1)
                # re.DOTALL so '.' also matches '\n' — without it, repetition
                # made of newlines (e.g. runs of blank lines) goes undetected.
                if re.search(pattern, self.generated_text, flags=re.DOTALL):
                    return True
        return False

def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128, device="cuda"):
    """Generate a story continuation for `prompt`, steered by tone/genre tags.

    Runs model.generate on a background thread and drains the streamer on
    this one. The returned text includes the prompt, since the streamer is
    not configured to skip it.

    Args:
        prompt: User-supplied story prompt.
        tone: Tone tag injected into the prompt (e.g. "dark").
        genre: Genre tag injected into the prompt (e.g. "fantasy").
        max_tokens: Maximum number of new tokens to generate.
        device: Device to place the input tensors on (default "cuda",
            matching the previous hard-coded behavior).

    Returns:
        The full generated text as a string.
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(device)
    text_streamer = TextIteratorStreamer(tokenizer)
    # Fresh criteria object per call: the repetition checker keeps state
    # across tokens and must not be shared between runs.
    stopping_criteria = StoppingCriteriaList([RepetitionStoppingCriteria(tokenizer)])

    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    try:
        # join() avoids linear-time string += in a loop; the streamer's
        # iterator ends when generation finishes.
        full_text = "".join(text_streamer)
    finally:
        # Always reap the worker thread, even if draining the stream raises.
        thread.join()

    return full_text

@app.route("/generate", methods=["POST"])
def generate():
    """POST /generate — JSON body: {"prompt": str, "tone"?: str,
    "genre"?: str, "max_tokens"?: int}. Returns {"story": str} on success
    or {"error": str} with HTTP 400 on bad input.
    """
    # silent=True makes a missing/non-JSON body yield None instead of
    # raising; the original's get_json() returned None for a non-JSON POST
    # and then crashed with AttributeError on data.get (an HTTP 500).
    data = request.get_json(silent=True)
    if not data:
        return jsonify({"error": "request body must be a JSON object"}), 400

    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "'prompt' is required"}), 400

    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    try:
        max_tokens = int(data.get("max_tokens", 128))
    except (TypeError, ValueError):
        return jsonify({"error": "'max_tokens' must be an integer"}), 400

    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})

if __name__ == "__main__":
    # Flask development server; 0.0.0.0 binds all interfaces — use a real
    # WSGI server (gunicorn/uwsgi) rather than this in production.
    app.run(host="0.0.0.0", port=5000)