Corin1998 committed on
Commit 9f5a128 · verified · 1 Parent(s): afd1e60

Create llm.py

Files changed (1)
  1. modules/llm.py +54 -0
modules/llm.py ADDED
@@ -0,0 +1,54 @@
+import os
+from typing import List
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
+import torch
+
+_model = None
+_tokenizer = None
+
+def _load_local():
+    global _model, _tokenizer
+    model_id = os.getenv("HF_LOCAL_MODEL_ID", "google/flan-t5-base")
+    if "t5" in model_id or "flan" in model_id:
+        _tokenizer = AutoTokenizer.from_pretrained(model_id)
+        _model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+    else:
+        _tokenizer = AutoTokenizer.from_pretrained(model_id)
+        _model = AutoModelForCausalLM.from_pretrained(model_id)
+    if torch.cuda.is_available():
+        _model = _model.to("cuda")
+
+def generate(system_prompt: str, user_prompt: str, temperature: float = 0.4, max_new_tokens: int = 512) -> str:
+    """
+    Simple prompt-to-text generation. Extend as needed.
+    """
+    use_api = os.getenv("USE_HF_INFERENCE_API", "false").lower() == "true"
+    if use_api:
+        # When using the Inference API, POST with requests (simple version)
+        import requests
+        api_url = f"https://api-inference.huggingface.co/models/{os.getenv('HF_LOCAL_MODEL_ID', 'google/flan-t5-base')}"
+        headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN', '')}"}
+        payload = {"inputs": f"{system_prompt}\n\n{user_prompt}", "parameters": {"temperature": temperature, "max_new_tokens": max_new_tokens}}
+        r = requests.post(api_url, headers=headers, json=payload, timeout=120)
+        r.raise_for_status()
+        data = r.json()
+        if isinstance(data, list) and len(data) and "generated_text" in data[0]:
+            return data[0]["generated_text"]
+        elif isinstance(data, dict) and "generated_text" in data:
+            return data["generated_text"]
+        return str(data)
+
+    # Local execution (Transformers)
+    if _model is None:
+        _load_local()
+    prompt = f"{system_prompt}\n\n{user_prompt}".strip()
+    inputs = _tokenizer(prompt, return_tensors="pt")
+    if torch.cuda.is_available():
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+    if hasattr(_model, "generate"):
+        with torch.no_grad():
+            out_ids = _model.generate(**inputs, do_sample=temperature > 0, temperature=temperature, max_new_tokens=max_new_tokens)
+        text = _tokenizer.decode(out_ids[0], skip_special_tokens=True)
+        return text
+    else:
+        return "Model not supported."
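
For reference, a minimal usage sketch of the new module (not part of the commit). It assumes the repository root is on PYTHONPATH so `modules.llm` is importable, and that the environment variables `HF_LOCAL_MODEL_ID` and `USE_HF_INFERENCE_API` are used as defined above; the prompts and parameter values are illustrative.

import os

# Optional: choose the backing model; llm.py falls back to google/flan-t5-base.
os.environ.setdefault("HF_LOCAL_MODEL_ID", "google/flan-t5-base")
# Leave USE_HF_INFERENCE_API unset (or "false") to run the model locally.

from modules import llm

answer = llm.generate(
    system_prompt="You are a concise assistant.",
    user_prompt="Summarize what a tokenizer does in one sentence.",
    temperature=0.4,
    max_new_tokens=128,
)
print(answer)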