minoD commited on
Commit
68622cd
·
verified ·
1 Parent(s): a6aaa5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -20,22 +20,26 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
20
 
21
  # プロンプトテンプレートの準備
22
  def generate_prompt(F):
23
- result = f"""### 指示:あなたは企業の面接官です.就活生のエントリーシートを元に質問を行ってください.### 質問:{F}### 回答:"""
 
 
 
 
 
 
24
  result = result.replace('\n', '<NL>')
25
  return result
26
 
27
  # テキスト生成関数の定義
28
- @spaces.GPU(duration=60) # タイムアウトを60秒に設定
29
  def generate2(F=None, maxTokens=256):
30
  try:
31
- # モデルをGPUに転送
32
  model.to("cuda")
33
 
34
- # 推論
35
  prompt = generate_prompt(F)
36
  input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
37
 
38
- with torch.no_grad(): # 勾配計算を無効化してメモリ節約
39
  outputs = model.generate(
40
  input_ids=input_ids,
41
  max_new_tokens=maxTokens,
@@ -46,9 +50,8 @@ def generate2(F=None, maxTokens=256):
46
  no_repeat_ngram_size=2,
47
  )
48
 
49
- # CPUに戻す
50
  model.to("cpu")
51
- torch.cuda.empty_cache() # GPUメモリをクリア
52
 
53
  outputs = outputs[0].tolist()
54
  decoded = tokenizer.decode(outputs)
@@ -58,12 +61,14 @@ def generate2(F=None, maxTokens=256):
58
  eos_index = outputs.index(tokenizer.eos_token_id)
59
  decoded = tokenizer.decode(outputs[:eos_index])
60
 
61
- # レスポンス内容のみ抽出
62
- sentinel = "### 回答:"
63
  sentinelLoc = decoded.find(sentinel)
64
  if sentinelLoc >= 0:
65
  result = decoded[sentinelLoc + len(sentinel):]
66
- return result.replace("<NL>", "\n")
 
 
67
  else:
68
  return 'Warning: Expected prompt template to be emitted. Ignoring output.'
69
 
 
20
 
21
  # プロンプトテンプレートの準備
22
  def generate_prompt(F):
23
+ result = f"""### 指示:
24
+ あなたは企業の面接官です。以下の就活生のエントリーシート内容を読んで、深掘りする質問を1つ考えてください。
25
+
26
+ ### エントリーシート:
27
+ {F}
28
+
29
+ ### 面接官の質問:"""
30
  result = result.replace('\n', '<NL>')
31
  return result
32
 
33
  # テキスト生成関数の定義
34
+ @spaces.GPU(duration=60)
35
  def generate2(F=None, maxTokens=256):
36
  try:
 
37
  model.to("cuda")
38
 
 
39
  prompt = generate_prompt(F)
40
  input_ids = tokenizer(prompt, return_tensors="pt", truncation=True, add_special_tokens=False).input_ids.to("cuda")
41
 
42
+ with torch.no_grad():
43
  outputs = model.generate(
44
  input_ids=input_ids,
45
  max_new_tokens=maxTokens,
 
50
  no_repeat_ngram_size=2,
51
  )
52
 
 
53
  model.to("cpu")
54
+ torch.cuda.empty_cache()
55
 
56
  outputs = outputs[0].tolist()
57
  decoded = tokenizer.decode(outputs)
 
61
  eos_index = outputs.index(tokenizer.eos_token_id)
62
  decoded = tokenizer.decode(outputs[:eos_index])
63
 
64
+ # レスポンス内容のみ抽出(修正)
65
+ sentinel = "### 面接官の質問:"
66
  sentinelLoc = decoded.find(sentinel)
67
  if sentinelLoc >= 0:
68
  result = decoded[sentinelLoc + len(sentinel):]
69
+ # 最初の改行までを取得(1つの質問だけ)
70
+ result = result.split('\n')[0] if '\n' in result else result
71
+ return result.replace("<NL>", "\n").strip()
72
  else:
73
  return 'Warning: Expected prompt template to be emitted. Ignoring output.'
74