# Spaces: Runtime error — Hugging Face Spaces status banner captured along
# with the source during scraping; kept here as an inert comment.
# Standard library.
import os
import re
import time  # NOTE(review): `time` is not referenced anywhere in this file.

# Third party.
import streamlit as st

# Hub identifier of the model/tokenizer loaded by load_models() and shown
# as the page title.
model_id = "google/codegemma-7b-it"
def strip_bos_eos(text_tagged):
    """Return the text between the first ``<bos>`` and last ``<eos>`` marker.

    Args:
        text_tagged: Decoded model output that may carry ``<bos>``/``<eos>``
            special-token markers.

    Returns:
        The substring strictly between the markers, or ``text_tagged``
        unchanged when no ``<bos>…<eos>`` pair is present.
    """
    # re.search with a greedy capture is the idiomatic form of the original
    # re.match + lazy-prefix + lookaround construction: it anchors on the
    # first <bos> and, because (.*) is greedy, extends to the last <eos>.
    # DOTALL lets the capture span newlines in multi-line generations.
    m = re.search(r"<bos>(.*)<eos>", text_tagged, flags=re.DOTALL)
    return m.group(1) if m else text_tagged
def load_models():
    """Load the HF access token, tokenizer, and model for ``model_id``.

    Returns:
        tuple: ``(token, tokenizer, model)``.

    Raises:
        KeyError: if ``HF_TOKEN`` is not set in the environment (after
            loading any local ``.env`` file).
    """
    # Deferred imports keep heavyweight dependencies off the module import path.
    from dotenv import load_dotenv
    from transformers import GemmaTokenizer, AutoModelForCausalLM

    load_dotenv()  # pull HF_TOKEN from a local .env file into os.environ
    _token = os.environ["HF_TOKEN"]
    # Fix: the original read HF_TOKEN but never passed it to from_pretrained,
    # so the downloads ran unauthenticated (codegemma is typically a gated
    # repo that requires the token — confirm against the Hub repo settings).
    _tokenizer = GemmaTokenizer.from_pretrained(model_id, token=_token)
    _model = AutoModelForCausalLM.from_pretrained(model_id, token=_token)
    return _token, _tokenizer, _model
def process(_input_text):
    """Generate a completion for *_input_text* with the codegemma model.

    Loads the tokenizer/model (via ``load_models``), tokenizes the prompt,
    generates up to 4092 new tokens, and returns the decoded text with the
    ``<bos>``/``<eos>`` markers stripped.
    """
    _, tokenizer, model = load_models()
    encoded = tokenizer(_input_text, return_tensors="pt")
    generated = model.generate(**encoded, max_new_tokens=4092)
    decoded = tokenizer.decode(generated[0])
    return strip_bos_eos(decoded)
def _main():
    """Render the minimal Streamlit UI: a prompt box and a submit button."""
    # Original behavior kept: load_models() is called up front and its result
    # discarded (process() calls it again on submit).
    load_models()
    st.title(model_id)
    prompt = st.text_input("Prompt")
    if st.button("Submit"):
        st.write(process(prompt))


if __name__ == '__main__':
    _main()