jasperBHOS committed on
Commit
417b6c4
·
verified ·
1 Parent(s): c2d9187

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, request, jsonify
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
import textwrap
import torch
import re

# Flask application serving the story-generation endpoint below.
app = Flask(__name__)

# Assuming tokenizer and model are loaded
# tokenizer = ...
# model = load_model("/path/to/your/model.pth")

# NOTE(review): currently unused in this file (as is `textwrap`); presumably
# intended for console formatting of generated text — confirm before removing.
max_print_width = 150
16
class RepetitionStoppingCriteria(StoppingCriteria):
    """Stop generation when the decoded tail of the output repeats itself.

    Accumulates the decoded text token by token and stops once any
    substring of length 1..window_size occurs `threshold` times in a row
    within the most recent ``window_size * threshold`` characters.

    Args:
        tokenizer: Tokenizer used to decode newly generated token ids.
        window_size: Maximum length of the repeating unit to look for.
        threshold: Number of consecutive repetitions that triggers a stop.
    """

    def __init__(self, tokenizer, window_size=10, threshold=3):
        self.tokenizer = tokenizer
        self.window_size = window_size
        self.threshold = threshold
        self.generated_text = ""
        # Pre-compile one pattern per repeat-unit length (they only depend
        # on __init__ arguments); the original rebuilt them on every call.
        # Each matches `threshold` consecutive copies of an i-char substring.
        self._patterns = [
            re.compile(r"(.{%d})\1{%d,}" % (i, threshold - 1))
            for i in range(1, window_size + 1)
        ]

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newest token of the first sequence in the batch.
        # NOTE(review): per-token decode can mis-render multi-token BPE
        # pieces; acceptable for a heuristic repetition check.
        self.generated_text += self.tokenizer.decode(input_ids[0][-1:])

        window = self.window_size * self.threshold
        # Only the last `window` characters are ever inspected, so trim the
        # buffer to keep memory bounded on long generations (the original
        # buffer grew without limit).
        if len(self.generated_text) > window:
            self.generated_text = self.generated_text[-window:]

        if len(self.generated_text) >= window:
            return any(p.search(self.generated_text) for p in self._patterns)
        return False
34
+
35
def generate_story(prompt, tone="dark", genre="fantasy", max_tokens=128):
    """Generate a story continuation for `prompt`, steered by tone and genre.

    Runs `model.generate` on a background thread and collects the streamed
    text on this one.

    Args:
        prompt: User-supplied story prompt.
        tone: Tone tag injected ahead of the prompt (default "dark").
        genre: Genre tag injected ahead of the prompt (default "fantasy").
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The full streamed text (includes the prompt, since the streamer is
        not created with skip_prompt=True).
    """
    enhanced_prompt = f"[Tone: {tone}] [Genre: {genre}]\n{prompt}"
    # Fall back to CPU when CUDA is unavailable instead of crashing
    # (the original hard-coded "cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer([enhanced_prompt], return_tensors="pt").to(device)

    text_streamer = TextIteratorStreamer(tokenizer)
    repetition_checker = RepetitionStoppingCriteria(tokenizer)
    stopping_criteria = StoppingCriteriaList([repetition_checker])

    generation_kwargs = dict(
        inputs,
        streamer=text_streamer,
        max_new_tokens=max_tokens,
        use_cache=True,
        do_sample=True,
        temperature=0.7,
        top_p=0.92,
        top_k=50,
        repetition_penalty=1.2,
        stopping_criteria=stopping_criteria,
    )

    # generate() blocks, so run it on a worker thread and drain the
    # streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Collect chunks and join once: avoids quadratic `+=` concatenation.
    pieces = [chunk for chunk in text_streamer]
    # The streamer is exhausted when generation ends, but join the worker
    # explicitly so it cannot outlive this call (original never joined).
    thread.join()

    return "".join(pieces)
63
+
64
@app.route("/generate", methods=["POST"])
def generate():
    """POST /generate — JSON body: prompt (required), tone, genre, max_tokens.

    Returns:
        200 with {"story": <generated text>} on success,
        400 with {"error": ...} when the body is missing/not JSON or the
        prompt is empty.
    """
    # silent=True yields None instead of raising on a missing or non-JSON
    # body; the original `data.get(...)` would have crashed with a 500.
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt", "")
    if not prompt:
        return jsonify({"error": "prompt is required"}), 400
    tone = data.get("tone", "dark")
    genre = data.get("genre", "fantasy")
    max_tokens = data.get("max_tokens", 128)

    story = generate_story(prompt, tone, genre, max_tokens)
    return jsonify({"story": story})
74
+
75
if __name__ == "__main__":
    # NOTE(review): binds Flask's development server to all interfaces
    # (0.0.0.0:5000); use a production WSGI server and a restricted bind
    # address before deploying.
    app.run(host="0.0.0.0", port=5000)