import os
import json
import requests
from typing import Optional, List

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM

class LLMClient:
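    """Thin wrapper around either the Hugging Face Inference API (requires the
    HF_TOKEN environment variable) or local transformers pipelines, with simple
    rule-based fallbacks when neither produces output."""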
    def __init__(self, use_inference_api: bool = False):
        self.use_inference_api = use_inference_api
        self.hf_token = os.getenv("HF_TOKEN", None)
        self._local_pipeline = {}  # cache of local pipelines keyed by (model, task)

    # ---------- Inference API helpers ----------
    def _hf_headers(self):
        if not self.hf_token:
            raise RuntimeError("HF_TOKEN is not set for Inference API usage.")
        return {"Authorization": f"Bearer {self.hf_token}"}
    
    def _hf_textgen(self, model: str, prompt: str, max_new_tokens: int = 512, temperature: float = 0.3) -> str:
        url = f"https://api-inference.huggingface.co/models/{model}"
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "return_full_text": False
            }
        }
        r = requests.post(url, headers=self._hf_headers(), json=payload, timeout=120)
        r.raise_for_status()
        data = r.json()
        # Most text-generation models return a list like [{"generated_text": ...}]
        if isinstance(data, list) and len(data) > 0 and "generated_text" in data[0]:
            return data[0]["generated_text"]
        
        # Some models return dict{"generated_text": ...}
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"]
        return str(data)
    
    def _hf_summarize(self, model: str, text: str, max_new_tokens: int = 256) -> str:
        # Many summarization models work with this generic text-generation endpoint as well
        return self._hf_textgen(model=model, prompt=text, max_new_tokens=max_new_tokens)
    
    # ---------- Local pipelines ----------
    def _get_local_pipeline(self, model: str, task: str):
        key = (model, task)
        if key in self._local_pipeline:
            return self._local_pipeline[key]
        # task is e.g. "text2text-generation" (Japanese T5 models) or "summarization"
        pipe = pipeline(task=task, model=model)
        self._local_pipeline[key] = pipe
        return pipe
    
    # ---------- Public methods ----------
    def summarize(self, text: str, model: str, max_words: int = 200) -> str:
        if self.use_inference_api:
            out = self._hf_summarize(model=model, text=text, max_new_tokens=max_words * 2)
            return out.strip()
        # Local: try a summarization pipeline first
        try:
            if "t5" in model.lower():
                # Many Japanese T5 models expect an instruction prefix
                pipe = self._get_local_pipeline("text2text-generation", model)
                prompt = f"要約: {text[:6000]}"
                res = pipe(prompt, max_length=max_words*2, do_sample=False)
                return res[0]['generated_text'].strip()
            else:
                pipe = self._get_local_pipeline("summarization", model)
                res = pipe(text[:6000], max_length=max_words*2, min_length=max_words//2, do_sample=False)
                return res[0]['summary_text'].strip()
        except Exception:
            # Very robust fallback: return the first few lines of the input
            return "\n".join(text.split("\n")[:6])
    
    def generate(self, prompt: str, model: Optional[str] = None, max_new_tokens: int = 512) -> str:
        model = model or ""  # the user may leave this empty
        if self.use_inference_api and model:
            return self._hf_textgen(model, prompt, max_new_tokens=max_new_tokens)
        # Local fallback: no heavy local chat model required;
        # we rely on rule-based extractors when no generation model is available.
        return ""