# SFT / src/streamlit_app.py
# (Hugging Face Space page header preserved as comments so the file is valid Python:)
# choco-conoz's picture
# feat: change logic
# dad4e7a
# raw / history / blame
# 2.73 kB
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# from huggingface_hub import notebook_login
# from unsloth import FastLanguageModel, is_bfloat16_supported
# model_id = "sentence-transformers/all-MiniLM-L6-v2"
# model_id = "sentence-transformers/xlm-r-base-en-ko-nli-ststb"
# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# model_id = "meta-llama/Llama-3.2-1B"
model_id = "choco-conoz/TwinLlama-3.1-8B"

# Text-generation pipeline loaded in 4-bit with fp16 compute to fit the
# 8B model in limited GPU memory.
processor = pipeline(
"text-generation",
model=model_id,
model_kwargs={
"torch_dtype": torch.float16,
"quantization_config": {"load_in_4bit": True},
"low_cpu_mem_usage": True,
},
)

# Stop generation on either the model's EOS token or the Llama-3.x
# end-of-turn token. The original code called convert_tokens_to_ids("")
# (empty string), which yields the unk id or None rather than a real
# stop token — "<|eot_id|>" is the Llama-3.1 end-of-turn marker.
# NOTE(review): assumes this checkpoint keeps the Llama-3.1 tokenizer —
# TODO confirm against the model's tokenizer_config.
terminators = [
tok_id
for tok_id in (
processor.tokenizer.eos_token_id,
processor.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
)
if tok_id is not None  # guard against a missing/unknown token id
]
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(model_id)
# processor = pipeline(
# "text-generation",
# model=model,
# tokenizer=tokenizer,
# max_new_tokens=10
# )
def main():
    """Render the Streamlit UI: take a topic, generate text with the model.

    Reads the module-level ``processor`` pipeline and ``terminators`` list.
    Side effects: draws Streamlit widgets and writes the model response.
    """
    st.title('Text Generator')
    query = st.text_input('input your topic of interest')
    # Alpaca-style instruction template; the empty second slot leaves the
    # Response section open for the model to complete.
    alpaca_template = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{}
### Response:
{}
"""
    if st.button("Send"):
        user_prompt = alpaca_template.format(query, "")
        print('user_prompt', user_prompt)
        # apply_chat_template requires a list of {"role", "content"} message
        # dicts — the original passed the raw string, which the transformers
        # API rejects. Wrap the formatted prompt as a single user turn.
        messages = [{"role": "user", "content": user_prompt}]
        prompt = processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        outputs = processor(prompt, max_new_tokens=4096, eos_token_id=terminators, do_sample=True,
                            temperature=0.6, top_p=0.9
                            )
        # The pipeline echoes the prompt; slice it off to keep only the
        # newly generated continuation.
        response = outputs[0]["generated_text"][len(prompt):]
        st.write(response)


if __name__ == "__main__":
    main()
# >>> old
# num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
# num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
# indices = np.linspace(0, 1, num_points)
# theta = 2 * np.pi * num_turns * indices
# radius = indices
# x = radius * np.cos(theta)
# y = radius * np.sin(theta)
# df = pd.DataFrame({
# "x": x,
# "y": y,
# "idx": indices,
# "rand": np.random.randn(num_points),
# })
# st.altair_chart(alt.Chart(df, height=700, width=700)
# .mark_point(filled=True)
# .encode(
# x=alt.X("x", axis=None),
# y=alt.Y("y", axis=None),
# color=alt.Color("idx", legend=None, scale=alt.Scale()),
# size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
# ))