import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Lazily-initialized module-level model/tokenizer cache.
_model = None
_tokenizer = None

def _load_local():
    """Load the configured checkpoint once and cache it at module level."""
    global _model, _tokenizer
    model_id = os.getenv("HF_LOCAL_MODEL_ID", "google/flan-t5-base")
    _tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Flan/T5 checkpoints are encoder-decoder; anything else is loaded as a causal LM.
    if "t5" in model_id or "flan" in model_id:
        _model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    else:
        _model = AutoModelForCausalLM.from_pretrained(model_id)
    if torch.cuda.is_available():
        _model = _model.to("cuda")

def generate(system_prompt: str, user_prompt: str, temperature: float = 0.4, max_new_tokens: int = 512) -> str:
    """Generate a completion via the hosted HF Inference API or a local model."""
    use_api = os.getenv("USE_HF_INFERENCE_API", "false").lower() == "true"
    if use_api:
        import requests

        model_id = os.getenv("HF_LOCAL_MODEL_ID", "google/flan-t5-base")
        api_url = f"https://api-inference.huggingface.co/models/{model_id}"
        headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN', '')}"}
        payload = {
            "inputs": f"{system_prompt}\n\n{user_prompt}",
            "parameters": {"temperature": temperature, "max_new_tokens": max_new_tokens},
        }
        r = requests.post(api_url, headers=headers, json=payload, timeout=120)
        r.raise_for_status()
        data = r.json()
        # The API returns either a list of generations or a single dict.
        if isinstance(data, list) and data and "generated_text" in data[0]:
            return data[0]["generated_text"]
        if isinstance(data, dict) and "generated_text" in data:
            return data["generated_text"]
        return str(data)
    if _model is None:
        _load_local()
    prompt = f"{system_prompt}\n\n{user_prompt}".strip()
    inputs = _tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
    # temperature <= 0 falls back to greedy decoding; only pass temperature when sampling.
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    with torch.no_grad():
        out_ids = _model.generate(**inputs, **gen_kwargs)
    if not _model.config.is_encoder_decoder:
        # Causal LMs echo the prompt in their output; return only the new tokens.
        out_ids = out_ids[:, inputs["input_ids"].shape[1]:]
    return _tokenizer.decode(out_ids[0], skip_special_tokens=True)
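
# Example usage — a minimal sketch. The env vars are the ones this module
# already reads; the prompts below are illustrative placeholders.
if __name__ == "__main__":
    # Runs locally by default (HF_LOCAL_MODEL_ID, falling back to google/flan-t5-base).
    # Set USE_HF_INFERENCE_API=true plus HF_API_TOKEN to route the same call
    # through the hosted Inference API instead.
    reply = generate(
        system_prompt="You are a concise assistant.",
        user_prompt="Summarize what a tokenizer does in one sentence.",
        temperature=0.4,
        max_new_tokens=64,
    )
    print(reply)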