MarcoLeung052 committed on
Commit
421fc88
·
verified ·
1 Parent(s): c739de4

Update backend/ai_output.py

Browse files
Files changed (1) hide show
  1. backend/ai_output.py +14 -6
backend/ai_output.py CHANGED
@@ -1,7 +1,6 @@
1
- # backend/ai_output.py
2
-
3
  from fastapi import HTTPException
4
  from .model_loader import model, tokenizer
 
5
 
6
  def run_ai_output(input_text: str):
7
 
@@ -12,11 +11,21 @@ def run_ai_output(input_text: str):
12
  raise HTTPException(status_code=400, detail="輸入過長,請限制在 512 字元內")
13
 
14
  try:
15
- input_ids = tokenizer.encode(input_text, return_tensors='pt', truncation=True)
 
 
 
 
 
 
 
 
 
16
 
17
  output = model.generate(
18
- input_ids,
19
- max_length=len(input_text) + 150,
 
20
  num_return_sequences=3,
21
  no_repeat_ngram_size=3,
22
  do_sample=True,
@@ -34,7 +43,6 @@ def run_ai_output(input_text: str):
34
 
35
  completions = sorted(list(set(completions)), key=len, reverse=True)
36
 
37
- # ⭐ 統一回傳格式(AI skill)
38
  return {
39
  "type": "ai-multi-options",
40
  "options": completions or [input_text]
 
 
 
1
  from fastapi import HTTPException
2
  from .model_loader import model, tokenizer
3
+ import torch
4
 
5
  def run_ai_output(input_text: str):
6
 
 
11
  raise HTTPException(status_code=400, detail="輸入過長,請限制在 512 字元內")
12
 
13
  try:
14
+ # 正確 attention_mask
15
+ encoded = tokenizer(
16
+ input_text,
17
+ return_tensors="pt",
18
+ truncation=True,
19
+ padding=False
20
+ )
21
+
22
+ input_ids = encoded["input_ids"]
23
+ attention_mask = encoded["attention_mask"]
24
 
25
  output = model.generate(
26
+ input_ids=input_ids,
27
+ attention_mask=attention_mask, # 加上這行
28
+ max_length=input_ids.shape[1] + 150,
29
  num_return_sequences=3,
30
  no_repeat_ngram_size=3,
31
  do_sample=True,
 
43
 
44
  completions = sorted(list(set(completions)), key=len, reverse=True)
45
 
 
46
  return {
47
  "type": "ai-multi-options",
48
  "options": completions or [input_text]