AoEiuV020 commited on
Commit
916d7eb
·
1 Parent(s): 88456ad

Update app.py with Gradio interface

Browse files
Files changed (2) hide show
  1. app.py +52 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import random
4
+ import gradio as gr
5
+
6
+ # Replace with your actual model path
7
+ transformers_model_path = "jingyaogong/MiniMind2"
8
+
9
+ # Load the tokenizer and model
10
+ tokenizer = AutoTokenizer.from_pretrained(transformers_model_path)
11
+ model = AutoModelForCausalLM.from_pretrained(transformers_model_path, trust_remote_code=True).eval()
12
+
13
+ def setup_seed(seed):
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed_all(seed)
16
+ random.seed(seed)
17
+
18
+ def predict(prompt):
19
+ messages = []
20
+ max_seq_len = 128
21
+ history_cnt = 0
22
+ model_mode = 2
23
+ setup_seed(random.randint(0, 2048))
24
+ messages = messages[-history_cnt:] if history_cnt else []
25
+ messages.append({"role": "user", "content": prompt})
26
+ new_prompt = tokenizer.apply_chat_template(
27
+ messages,
28
+ tokenize=False,
29
+ add_generation_prompt=True
30
+ )[-max_seq_len - 1:] if model_mode != 0 else (tokenizer.bos_token + prompt)
31
+
32
+ with torch.no_grad():
33
+ x = torch.tensor(tokenizer(new_prompt)['input_ids'], device='cpu').unsqueeze(0)
34
+ outputs = model.generate(
35
+ x,
36
+ eos_token_id=tokenizer.eos_token_id,
37
+ max_new_tokens=max_seq_len,
38
+ temperature=0.7,
39
+ top_p=0.95,
40
+ pad_token_id=tokenizer.pad_token_id
41
+ )
42
+ return tokenizer.decode(outputs.squeeze()[x.shape[1]:].tolist(), skip_special_tokens=True)
43
+
44
+ iface = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Textbox(lines=2, placeholder="Enter your text here..."),
47
+ outputs="text",
48
+ title="MiniMind2 Chatbot",
49
+ description="Enter text and see the model's response."
50
+ )
51
+
52
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio