| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
class ChatApp:
    """Minimal Gradio chat UI wrapped around a Hugging Face causal LM."""

    def __init__(self, model_id):
        """Store the model identifier; call ``load_model()`` before generating.

        Args:
            model_id: Hugging Face hub id (e.g. ``"gpt2"``).
        """
        self.model_id = model_id
        self.model = None       # populated by load_model()
        self.tokenizer = None   # populated by load_model()

    def load_model(self):
        """Load the model and tokenizer for ``self.model_id`` from the hub."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    def generate_response(self, input_text, max_length=100):
        """Generate a text completion for *input_text*.

        Args:
            input_text: Prompt string typed by the user.
            max_length: Total token budget (prompt + completion). Was a
                hard-coded constant; now a parameter with the same default,
                so existing callers are unaffected.

        Returns:
            The decoded output sequence. Note this includes the prompt,
            because the full generated sequence (``output[0]``) is decoded.

        Raises:
            RuntimeError: If ``load_model()`` has not been called yet
                (previously this crashed with an opaque ``AttributeError``
                on ``None.encode``).
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Call load_model() before generate_response().")
        input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
        output = self.model.generate(input_ids, max_length=max_length)
        return self.tokenizer.decode(output[0], skip_special_tokens=True)

    def start_chat(self):
        """Launch a blocking text-in/text-out Gradio interface."""
        # Gradio passes the single text input as the one positional arg;
        # max_length keeps its default.
        gr.Interface(fn=self.generate_response, inputs="text", outputs="text").launch()
def main():
    """Build the chat app for the default model and launch its UI."""
    app = ChatApp("gpt2")
    app.load_model()
    app.start_chat()
| if __name__ == "__main__": | |
| main() |