# NOTE: stray "Spaces: / Runtime error" log lines from the hosting platform
# were captured at the top of this file; converted to a comment so the module parses.
| from typing import Any, List, Mapping, Optional | |
| from langchain_core.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain_core.language_models.llms import LLM | |
| from typing import Literal | |
| import requests | |
| from langchain.prompts import PromptTemplate, ChatPromptTemplate | |
| from operator import itemgetter | |
| import re, os | |
def format_captions(text):
    """Split free text into sentence-like fragments and number them.

    Fragments are produced by splitting on any character that is NOT a
    letter, digit, space, or hyphen (so '.', ',', '!' etc. all act as
    delimiters), then stripping whitespace and dropping empty pieces.

    Args:
        text: Raw caption text.

    Returns:
        One numbered fragment per line, e.g. "1. foo\n2. bar".
        Empty string if no fragments survive filtering.
    """
    # Leftover debug print(len(...)) removed.
    fragments = (piece.strip() for piece in re.split(r'[^A-Za-z0-9 -]', text))
    return "\n".join(
        f"{i + 1}. {fragment}"
        for i, fragment in enumerate(f for f in fragments if f)
    )
def custom_chain():
    """Build an LCEL chain turning caption text into one-word descriptions.

    The input string is reformatted into a numbered list by
    ``format_captions`` and sent to a Mixtral instruct model hosted on the
    Hugging Face Inference API via ``CustomLLM``.

    Requires:
        The ``HF_INFER_API`` environment variable must hold a valid
        Hugging Face API token (raises KeyError otherwise).

    Returns:
        A runnable chain mapping raw caption text -> model completion string.
    """
    api_token = os.environ['HF_INFER_API']
    # NOTE(review): Mixtral's official instruct format is "[INST] ... [/INST]";
    # the "<INST>" tags below are preserved as-is to keep current behavior,
    # but verify they produce the intended completions.
    prompt = PromptTemplate.from_template(
        "<s><INST>Given the below template, for each number, create a detailed description with maximum one word\n\n{template}<INST> "
    )
    cap_llm = CustomLLM(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_type='text-generation',
        api_token=api_token,
        stop=["\n<|"],
        temperature=0.7,
    )
    # LCEL coerces the plain function into a RunnableLambda.
    return {"template": format_captions} | prompt | cap_llm
class CustomLLM(LLM):
    """LangChain LLM wrapper around the Hugging Face Inference API.

    POSTs the prompt to
    ``https://api-inference.huggingface.co/models/{repo_id}`` and returns
    the generated text of the first candidate.
    """

    # HF model id, e.g. "mistralai/Mixtral-8x7B-Instruct-v0.1".
    repo_id: str
    # HF API bearer token.
    api_token: str
    model_type: Literal["text2text-generation", "text-generation"]
    # Generation parameters; None means "not sent" (fixed annotations —
    # these defaulted to None while annotated as plain int/float).
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    # Pydantic deep-copies field defaults per instance, so the mutable
    # default list is safe here.
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        # LangChain expects _llm_type to be a property, not a plain method.
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send *prompt* to the HF Inference API and return the completion.

        Args:
            prompt: Fully-formatted prompt string.
            stop: Per-call stop sequences; falls back to ``self.stop``.
                (The original implementation silently ignored this argument.)

        Returns:
            The ``generated_text`` of the first returned candidate.

        Raises:
            ValueError: if the API responds with an error payload.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        api_url = f"https://api-inference.huggingface.co/models/{self.repo_id}"
        parameters = {
            'max_new_tokens': self.max_new_tokens,
            'temperature': self.temperature,
            'timeout': self.timeout,
            'top_p': self.top_p,
            'top_k': self.top_k,
            'repetition_penalty': self.repetition_penalty,
            'stop': stop if stop is not None else self.stop,
        }
        # Drop unset parameters instead of sending JSON nulls to the API.
        parameters = {k: v for k, v in parameters.items() if v is not None}
        if self.model_type == 'text-generation':
            # Return only the completion, not prompt + completion.
            parameters["return_full_text"] = False
        payload = {
            "inputs": prompt,
            "parameters": parameters,
            "options": {"wait_for_model": True},
        }
        response = requests.post(api_url, headers=headers, json=payload).json()
        # The API returns a list of generations on success and a dict
        # (e.g. {"error": ...}) on failure — surface that instead of
        # crashing with an opaque KeyError.
        if isinstance(response, dict):
            raise ValueError(
                f"Hugging Face Inference API error: {response.get('error', response)}"
            )
        return response[0]['generated_text']