| | import torch |
| | import gradio as gr |
| | from transformers import GPT2LMHeadModel, GPT2TokenizerFast |
| |
|
| | |
# Model identifier handed to from_pretrained (a Hugging Face Hub repo id here).
MODEL_REPO = "i3-lab/i3-GPT2"

# Choose the execution device up front: CUDA when torch can see a GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load tokenizer and weights once at import time; every request reuses them.
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_REPO)
# nn.Module.to() moves the parameters in place and returns the same module object.
model = GPT2LMHeadModel.from_pretrained(MODEL_REPO).to(device)
| |
|
def _as_text(value):
    """Return *value* as prompt text, mapping ``None`` (e.g. an unanswered turn) to ""."""
    return "" if value is None else str(value)


def _history_pairs(history):
    """Normalize Gradio chat history into ``(user_text, assistant_text)`` pairs.

    Accepts both layouts ``gr.ChatInterface`` can pass, depending on Gradio
    version / configuration:

    * "tuples": ``[[user, assistant], ...]`` where ``assistant`` may be ``None``;
    * "messages": ``[{"role": "user" | "assistant", "content": ...}, ...]``.

    Dict entries with any other role (e.g. "system") are ignored. A user message
    with no following assistant reply is paired with an empty string.
    """
    pairs = []
    pending_user = None
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role")
            text = _as_text(entry.get("content"))
            if role == "user":
                if pending_user is not None:
                    # Two user messages in a row: close the earlier one unanswered.
                    pairs.append((pending_user, ""))
                pending_user = text
            elif role == "assistant":
                pairs.append(("" if pending_user is None else pending_user, text))
                pending_user = None
        else:
            user_msg, assistant_msg = entry
            pairs.append((_as_text(user_msg), _as_text(assistant_msg)))
    if pending_user is not None:
        pairs.append((pending_user, ""))
    return pairs


def generate_response(message, history):
    """Sample the assistant's reply to *message* given the prior chat *history*.

    The conversation is serialized as ``User: ...`` / ``Assistant: ...`` lines.
    Each completed turn is terminated by ``<|endoftext|>``, and the prompt ends
    with an open ``Assistant:`` cue for the model to complete.

    Args:
        message: The user's latest message.
        history: Prior turns from ``gr.ChatInterface``, in either the "tuples"
            or the "messages" layout (see ``_history_pairs``).

    Returns:
        The generated text up to the first ``User:`` the model emits
        (presumably to drop a hallucinated next user turn), stripped of
        surrounding whitespace.
    """
    max_new_tokens = 150

    prompt = "".join(
        f"User: {user_text}\nAssistant: {assistant_text}<|endoftext|>\n"
        for user_text, assistant_text in _history_pairs(history)
    )
    prompt += f"User: {message}\nAssistant:"

    encoded = tokenizer(prompt, return_tensors="pt")
    # GPT-2 has a fixed number of position embeddings (config.n_positions).
    # Prompt plus generated tokens must fit in that window, or generation fails
    # once positions run past it. Left-truncate so that the most recent context,
    # including the current message and the trailing "Assistant:" cue, survives.
    input_budget = max(1, model.config.n_positions - max_new_tokens)
    inputs = {
        name: tensor[:, -input_budget:].to(device)
        for name, tensor in encoded.items()
    }
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)

    clean_response = response.split("User:")[0].strip()
    return clean_response
| |
|
| | |
# Chat UI: Gradio invokes generate_response(message, history) for each user turn.
demo = gr.ChatInterface(
    fn=generate_response,
    examples=[
        "Tell me a joke.",
        "What is the capital of France?",
        "How does a lightbulb work?",
    ],
    title="i3-GPT",
)


def _launch_app():
    """Start the Gradio server hosting the chat demo."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()