MarcoLeung052 committed on
Commit
c770c0c
·
verified ·
1 Parent(s): dbeb65b

Update backend/ai_output.py

Browse files
Files changed (1) hide show
  1. backend/ai_output.py +43 -16
backend/ai_output.py CHANGED
@@ -1,25 +1,52 @@
1
# backend/ai_output.py

import os

BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def run_ai_output(field_name: str):
    """Return the skill prompt stored for *field_name*.

    Looks for ``backend/skills/<field_name>_ai/<field_name>.md`` and returns
    its stripped contents.  When no such skill file exists, a plain
    placeholder string is returned instead.  (An LLM call can be wired in
    here later, e.g. ``response = call_llm(prompt=content)``.)
    """
    skill_file = os.path.join(
        BASE_DIR, "skills", f"{field_name}_ai", f"{field_name}.md"
    )

    if os.path.exists(skill_file):
        with open(skill_file, "r", encoding="utf-8") as handle:
            return handle.read().strip()

    # No matching skill file -> fallback placeholder.
    return f"[AI] {field_name}"
 
1
# backend/ai_output.py

import torch
from fastapi import HTTPException

# model / tokenizer are loaded once in api_server.py and shared here.
from api_server import model, tokenizer


def run_ai_output(input_text: str):
    """Generate up to 3 completion candidates for *input_text* with the LLM.

    Returns a list of completions that extend the prompt, de-duplicated and
    sorted longest-first.  Falls back to ``[input_text]`` when no usable
    completion was produced.

    Raises:
        HTTPException 503: model/tokenizer not loaded yet.
        HTTPException 400: input longer than 512 characters.
        HTTPException 500: any failure during inference.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="AI 模型尚未準備就緒")

    if len(input_text) > 512:
        raise HTTPException(status_code=400, detail="輸入過長,請限制在 512 字元內")

    try:
        input_ids = tokenizer.encode(input_text, return_tensors='pt', truncation=True)

        # Inference only: disable autograd to cut memory use and latency.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                # BUG FIX: the previous code passed
                # max_length=len(input_text) + 150, which mixes the prompt's
                # *character* count with generate()'s *token* budget.
                # max_new_tokens expresses the continuation budget in tokens.
                max_new_tokens=150,
                num_return_sequences=3,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id,
            )

        all_completions = []
        for sequence in output:
            generated_text = tokenizer.decode(sequence, skip_special_tokens=True)
            # Keep only sequences that actually extend the original prompt.
            if generated_text.startswith(input_text):
                all_completions.append(generated_text)

        # De-duplicate; present the longest candidates first.
        unique_completions = sorted(set(all_completions), key=len, reverse=True)

        if not unique_completions:
            # Nothing usable came back -> echo the prompt as a safe fallback.
            return [input_text]

        return unique_completions

    except Exception as e:  # API boundary: translate any failure into HTTP 500
        print(f"AI 推論錯誤: {e}")
        raise HTTPException(status_code=500, detail=f"AI 推論失敗:{str(e)[:50]}...") from e