MahiH committed
Commit 1265f99 · 1 Parent(s): 93bd965

public access not working 2

Files changed (1):
  1. app.py +12 -6
app.py CHANGED
@@ -6,24 +6,30 @@ model_id = "MahiH/dialogpt-finetuned-chatbot"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id)
 model.eval()
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
 def chat(prompt):
     input_text = f"Human: {prompt}\nAssistant: "
-    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    inputs = tokenizer(input_text, return_tensors="pt").to(device)
     with torch.no_grad():
-        output_ids = model.generate(
-            input_ids,
+        outputs = model.generate(
+            **inputs,
             max_new_tokens=100,
             do_sample=True,
             top_p=0.95,
             temperature=0.8,
             pad_token_id=tokenizer.eos_token_id
         )
-    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return decoded.split("Assistant:")[-1].strip()
 
-# API-only Gradio interface
+# Create Interface
 demo = gr.Interface(fn=chat, inputs="text", outputs="text")
-demo.launch(enable_queue=True)
+
+# Enable queuing to support the REST API endpoint
+demo.queue()
+
+# Launch (no extra args needed)
+demo.launch()
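
Newer Gradio releases no longer accept enable_queue in launch(); queuing is switched on by calling demo.queue() before demo.launch(), and the commit's comment notes that the queue is what backs the REST API endpoint. A minimal sketch of calling that endpoint from outside the Space with gradio_client is shown below; the Space id is a placeholder (the diff does not name the Space), and "/predict" is the default endpoint a single-function gr.Interface exposes.

from gradio_client import Client

# Placeholder Space id -- substitute the actual Space that serves this app.py
client = Client("MahiH/dialogpt-finetuned-chatbot")

# A single-function gr.Interface exposes its fn at the default "/predict" endpoint
reply = client.predict("Hello there!", api_name="/predict")
print(reply)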