minoD commited on
Commit
560c76a
·
verified ·
1 Parent(s): 68622cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -9
app.py CHANGED
@@ -4,21 +4,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import spaces
5
  import os
6
 
7
- # bitsandbytesを無効化
8
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
9
 
10
  model_name = "minoD/JURAN"
11
 
12
- # モデルのロード(CPUで、bitsandbytesを使わない)
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
  device_map="cpu",
16
  torch_dtype=torch.float16,
17
- low_cpu_mem_usage=True, # メモリ効率を改善
18
  )
19
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
20
 
21
- # プロンプトテンプレートの準備
 
 
22
  def generate_prompt(F):
23
  result = f"""### 指示:
24
  あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。
@@ -30,10 +31,38 @@ def generate_prompt(F):
30
  result = result.replace('\n', '<NL>')
31
  return result
32
 
33
- # テキスト生成関数の定義
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  @spaces.GPU(duration=60)
35
  def generate2(F=None, maxTokens=256):
36
  try:
 
 
 
 
 
 
 
37
  model.to("cuda")
38
 
39
  prompt = generate_prompt(F)
@@ -56,17 +85,14 @@ def generate2(F=None, maxTokens=256):
56
  outputs = outputs[0].tolist()
57
  decoded = tokenizer.decode(outputs)
58
 
59
- # EOSトークンにヒットしたらデコード完了
60
  if tokenizer.eos_token_id in outputs:
61
  eos_index = outputs.index(tokenizer.eos_token_id)
62
  decoded = tokenizer.decode(outputs[:eos_index])
63
 
64
- # レスポンス内容のみ抽出(修正)
65
  sentinel = "### 面接官の質問:"
66
  sentinelLoc = decoded.find(sentinel)
67
  if sentinelLoc >= 0:
68
  result = decoded[sentinelLoc + len(sentinel):]
69
- # 最初の改行までを取得(1つの質問だけ)
70
  result = result.split('\n')[0] if '\n' in result else result
71
  return result.replace("<NL>", "\n").strip()
72
  else:
@@ -78,7 +104,6 @@ def generate2(F=None, maxTokens=256):
78
  def inference(input_text):
79
  return generate2(input_text)
80
 
81
- # Gradioインターフェース
82
  iface = gr.Interface(
83
  fn=inference,
84
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),
 
4
  import spaces
5
  import os
6
 
 
7
  os.environ["BITSANDBYTES_NOWELCOME"] = "1"
8
 
9
  model_name = "minoD/JURAN"
10
 
11
+ # モデルのロード
12
  model = AutoModelForCausalLM.from_pretrained(
13
  model_name,
14
  device_map="cpu",
15
  torch_dtype=torch.float16,
16
+ low_cpu_mem_usage=True,
17
  )
18
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19
 
20
+ # ウォームアップフラグ
21
+ warmup_done = False
22
+
23
  def generate_prompt(F):
24
  result = f"""### 指示:
25
  あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。
 
31
  result = result.replace('\n', '<NL>')
32
  return result
33
 
34
+ @spaces.GPU(duration=60)
35
+ def warmup_model():
36
+ """モデルのウォームアップ処理"""
37
+ global warmup_done
38
+ if not warmup_done:
39
+ print("ウォームアップ中...")
40
+ model.to("cuda")
41
+
42
+ # ダミー推論を実行
43
+ dummy_input = tokenizer("テスト", return_tensors="pt").input_ids.to("cuda")
44
+ with torch.no_grad():
45
+ _ = model.generate(
46
+ dummy_input,
47
+ max_new_tokens=10,
48
+ do_sample=False
49
+ )
50
+
51
+ model.to("cpu")
52
+ torch.cuda.empty_cache()
53
+ warmup_done = True
54
+ print("ウォームアップ完了")
55
+
56
  @spaces.GPU(duration=60)
57
  def generate2(F=None, maxTokens=256):
58
  try:
59
+ # ウォームアップ(初回のみ)
60
+ if not warmup_done:
61
+ warmup_model()
62
+
63
+ # 乱数シードを固定(オプション)
64
+ torch.manual_seed(42)
65
+
66
  model.to("cuda")
67
 
68
  prompt = generate_prompt(F)
 
85
  outputs = outputs[0].tolist()
86
  decoded = tokenizer.decode(outputs)
87
 
 
88
  if tokenizer.eos_token_id in outputs:
89
  eos_index = outputs.index(tokenizer.eos_token_id)
90
  decoded = tokenizer.decode(outputs[:eos_index])
91
 
 
92
  sentinel = "### 面接官の質問:"
93
  sentinelLoc = decoded.find(sentinel)
94
  if sentinelLoc >= 0:
95
  result = decoded[sentinelLoc + len(sentinel):]
 
96
  result = result.split('\n')[0] if '\n' in result else result
97
  return result.replace("<NL>", "\n").strip()
98
  else:
 
104
  def inference(input_text):
105
  return generate2(input_text)
106
 
 
107
  iface = gr.Interface(
108
  fn=inference,
109
  inputs=gr.Textbox(lines=5, label="学生時代に打ち込んだこと、研究、ESを入力", placeholder="半導体の研究に打ち込んだ"),