Spaces:
Runtime error
Runtime error
Commit
·
0e7db3f
1
Parent(s):
8b1de88
Update components/custom_llm.py
Browse files- components/custom_llm.py +6 -2
components/custom_llm.py
CHANGED
|
@@ -17,9 +17,13 @@ def format_captions(text):
|
|
| 17 |
def custom_chain():
|
| 18 |
API_TOKEN = os.environ['HF_INFER_API']
|
| 19 |
|
| 20 |
-
prompt = PromptTemplate.from_template("<s><INST>Given the below template, create a list of image generation prompt with maximum 5 words for each number\n\n{template}<INST> ")
|
| 21 |
|
| 22 |
-
cap_llm = CustomLLM(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_type='text-generation', api_token=API_TOKEN, stop=["\n<|"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
return {"template":lambda x:format_captions(x)} | prompt | cap_llm
|
| 25 |
|
|
|
|
| 17 |
def custom_chain():
|
| 18 |
API_TOKEN = os.environ['HF_INFER_API']
|
| 19 |
|
| 20 |
+
# prompt = PromptTemplate.from_template("<s><INST>Given the below template, create a list of image generation prompt with maximum 5 words for each number\n\n{template}<INST> ")
|
| 21 |
|
| 22 |
+
# cap_llm = CustomLLM(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_type='text-generation', api_token=API_TOKEN, stop=["\n<|"])
|
| 23 |
+
prompt = PromptTemplate.from_template("<s><INST>Given the below template, for each number, create a detailed description with maximum one word\n\n{template}<INST> ")
|
| 24 |
+
|
| 25 |
+
cap_llm = CustomLLM(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", model_type='text-generation', api_token=API_TOKEN, stop=["\n<|"], temperature=0.7)
|
| 26 |
+
|
| 27 |
|
| 28 |
return {"template":lambda x:format_captions(x)} | prompt | cap_llm
|
| 29 |
|