|
|
|
|
| import gradio as gr
|
| from transformers import AutoTokenizer, AutoModelForCausalLM
|
| import torch
|
|
|
|
|
# Pick the compute device once at import time: CUDA is preferred, then the
# MPS backend, and the CPU is the fallback when neither accelerator reports
# itself as available.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the causal language model served by this app.
model_name = "upstage/TinySolar-248m-4k"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Some checkpoints ship without a dedicated padding token; reuse the
# end-of-sequence token in that case so padding requests do not fail.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the weights, move them onto the selected device, and switch the module
# into evaluation mode (e.g. dropout disabled) since it is only used for inference.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()
|
|
|
|
|
def text_streamer(
    prompt: str,
    temperature: float = 0.5,
    *,
    max_new_tokens: int = 200,
    max_input_length: int = 512,
) -> str:
    """
    Generate a continuation of the input prompt using the pre-trained language model.

    Args:
        prompt (str): The input prompt to generate text from. ``None`` or a
            whitespace-only string is treated as "no prompt".
        temperature (float, optional): The temperature for sampling, controlling
            randomness in the output. Default is 0.5. A value <= 0 switches to
            greedy (deterministic) decoding, because sampling requires a strictly
            positive temperature.
        max_new_tokens (int, optional): Maximum number of tokens generated after
            the prompt. Default is 200.
        max_input_length (int, optional): Prompts longer than this many tokens
            are truncated before generation. Default is 512.

    Returns:
        str: The decoded model output (the prompt followed by its continuation,
        with special tokens removed), or a short help message when no prompt
        is given.
    """
    if prompt is None or not prompt.strip():
        return "No prompt given. Please enter a valid prompt to continue."

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_input_length,
    ).to(device)

    generation_kwargs = {
        # Budget *new* tokens: ``max_length`` counts the prompt too, so a prompt
        # longer than that budget (still allowed up to ``max_input_length``
        # tokens) would leave no room for any continuation.
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": 1,
        "no_repeat_ngram_size": 2,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Sampling rejects a non-positive temperature; decode greedily instead.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            **generation_kwargs,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
def start_gradio_interface(**launch_kwargs):
    """
    Set up and launch the Gradio interface for text generation.

    This function creates a Gradio interface that allows users to input a prompt
    and a sampling temperature, sends both to ``text_streamer``, and displays the
    returned text.

    Args:
        **launch_kwargs: Optional keyword arguments forwarded unchanged to the
            interface's ``launch`` method (for example ``share=True`` or
            ``server_port=7861``). With none given, Gradio's default launch
            settings are used.
    """
    interface = gr.Interface(
        fn=text_streamer,
        # Values are passed to text_streamer in this order: (prompt, temperature).
        inputs=[
            gr.Textbox(
                placeholder="Try 'I am in love with...' or 'I went to...'",
                label="Enter your prompt",
            ),
            gr.Slider(minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Temperature"),
        ],
        outputs="text",
        title="AI Text Streamer: Complete Your Thoughts with AI",
        # Reuse model_name so the UI always names the checkpoint that is loaded.
        description=(
            "This AI-powered text streamer helps you finish your sentences. "
            "Simply provide half of a sentence, and the model will generate the rest using the autoregressive "
            f"text generation model {model_name}. "
            "Perfect for creative writing, brainstorming, and expanding ideas!"
        ),
        flagging_mode="never",
    )
    interface.launch(**launch_kwargs)
|
|
|
|
|
|
|
# Start the web UI only when this file is executed directly, not when imported.
# Note: importing the module still runs the top-level device selection and
# tokenizer/model loading.
if __name__ == "__main__":
    start_gradio_interface()
|
|
|