drixo commited on
Commit
2d02001
Β·
verified Β·
1 Parent(s): 62a509e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -16
app.py CHANGED
@@ -1,18 +1,54 @@
1
- # Load model directly
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
- tokenizer = AutoTokenizer.from_pretrained("LivSterling/rc-tutor-llama3-merged")
5
- model = AutoModelForCausalLM.from_pretrained("LivSterling/rc-tutor-llama3-merged")
6
- messages = [
7
- {"role": "user", "content": "Who are you?"},
8
- ]
9
- inputs = tokenizer.apply_chat_template(
10
- messages,
11
- add_generation_prompt=True,
12
- tokenize=True,
13
- return_dict=True,
14
- return_tensors="pt",
15
- ).to(model.device)
16
-
17
- outputs = model.generate(**inputs, max_new_tokens=40)
18
- print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Gradio chat app serving a fine-tuned, merged LLaMA-3 tutor model.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hub repo ID of the merged (LoRA already folded in) model this app serves.
MODEL_ID = "LivSterling/rc-tutor-llama3-merged"

# Load model & tokenizer once at import time (first run downloads from the Hub).
# fp16 halves weight memory; device_map="auto" lets accelerate place weights on
# the available GPU(s).
# NOTE(review): float16 on a CPU-only host can be slow or unsupported —
# confirm the target Space hardware has a GPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
14
+
15
def chat_fn(message, history):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: Latest user message (str).
        history: Prior turns as (user, assistant) pairs — the tuple format
            gr.ChatInterface passes to its callback.

    Returns:
        The decoded assistant reply (str), with special tokens stripped.
    """
    # Rebuild the whole conversation in chat-template message format.
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        # Gradio may hand us a pending turn whose assistant slot is None;
        # unpacking it into the template would crash, so skip it.
        if bot_msg is not None:
            messages.append({"role": "assistant", "content": bot_msg})

    messages.append({"role": "user", "content": message})

    # tokenize=True + return_tensors="pt" yields a token-id tensor directly.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(model.device)

    # inference_mode avoids building the autograd graph during generation —
    # lower memory use and faster decoding, identical outputs.
    with torch.inference_mode():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode only the newly generated tokens: slice off the prompt prefix.
    response = tokenizer.decode(
        outputs[0][inputs.shape[-1]:],
        skip_special_tokens=True,
    )

    return response
45
+
46
# Wire the generation callback into a ready-made Gradio chat front-end.
demo = gr.ChatInterface(
    chat_fn,
    title="RC Tutor LLaMA-3 Chatbot",
    description="Powered by LivSterling/rc-tutor-llama3-merged",
)

# Only start the web server when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()