TRaw commited on
Commit
732be18
·
1 Parent(s): a826d4f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+
4
+ class ChatApp:
5
+ def __init__(self, model_id):
6
+ self.model_id = model_id
7
+ self.model = None
8
+ self.tokenizer = None
9
+
10
+ def load_model(self):
11
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
12
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
13
+
14
+ def generate_response(self, input_text):
15
+ input_ids = self.tokenizer.encode(input_text, return_tensors="pt")
16
+ output = self.model.generate(input_ids, max_length=100)
17
+ response = self.tokenizer.decode(output[0], skip_special_tokens=True)
18
+ return response
19
+
20
+ def start_chat(self):
21
+ gr.Interface(fn=self.generate_response, inputs="text", outputs="text").launch()
22
+
23
+ def main():
24
+ model_id = "gpt2"
25
+ chat_app = ChatApp(model_id)
26
+ chat_app.load_model()
27
+ chat_app.start_chat()
28
+
29
+ if __name__ == "__main__":
30
+ main()