MahiH committed
Commit 1dd75d9 · Parent(s): e1daf26

Add application file
Files changed (2):
  1. app.py  +35 -0
  2. requirements.txt  +4 -0
app.py ADDED
@@ -0,0 +1,35 @@
+ import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ # Load model + tokenizer
+ model_id = "MahiH/dialogpt-finetuned-chatbot"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id)
+ model.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Inference function
+ def chat(prompt):
+     input_text = f"Human: {prompt}\nAssistant: "
+     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         output_ids = model.generate(
+             input_ids,
+             max_new_tokens=100,
+             do_sample=True,
+             top_k=50,
+             top_p=0.95,
+             temperature=0.8,
+             pad_token_id=tokenizer.eos_token_id
+         )
+
+     response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+     return response.split("Assistant:")[-1].strip()
+
+ # Set up Gradio app (no UI, just API)
+ app = gr.Interface(fn=chat, inputs=gr.Text(), outputs=gr.Text())
+
+ app.launch()
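
Once this app is running on a Space, the gr.Interface endpoint can also be queried programmatically through the Gradio API. A minimal sketch using the gradio_client package (the Space id below is only an assumption mirroring the model id; substitute the actual Space name or URL):

# Client-side sketch; assumes `gradio_client` is installed and that the Space
# is published as "MahiH/dialogpt-finetuned-chatbot" (an assumption; replace it
# with the real Space id or the URL printed by app.launch()).
from gradio_client import Client

client = Client("MahiH/dialogpt-finetuned-chatbot")
# A single-function gr.Interface is served under the default "/predict" endpoint.
reply = client.predict("Hello there!", api_name="/predict")
print(reply)

The same sketch works against a locally launched copy of app.py by passing the local URL instead, e.g. Client("http://127.0.0.1:7860").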
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ # requirements.txt
+ transformers
+ torch
+ gradio