kokofixcomputers committed on
Commit
e36063a
·
1 Parent(s): d88cbed

Fix errors

Browse files
Files changed (1) hide show
  1. app.py +18 -22
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- import markdown
5
 
6
  model_name = "deepseek-ai/deepseek-coder-1.3b-base"
7
 
@@ -9,17 +8,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
10
  model.eval()
11
 
12
- def respond(message, chat_history, max_tokens, temperature, top_p):
13
- chat_history = chat_history or []
14
- chat_history.append(("User", message))
 
15
 
16
- full_prompt = ""
17
- for speaker, text in chat_history:
18
- prefix = "User: " if speaker == "User" else "Assistant: "
19
- full_prompt += prefix + text + "\n"
20
- full_prompt += "Assistant: "
 
21
 
22
- inputs = tokenizer(full_prompt, return_tensors="pt")
23
  outputs = model.generate(
24
  **inputs,
25
  max_new_tokens=max_tokens,
@@ -28,29 +29,24 @@ def respond(message, chat_history, max_tokens, temperature, top_p):
28
  do_sample=True,
29
  pad_token_id=tokenizer.eos_token_id,
30
  )
31
- reply = tokenizer.decode(outputs[0], skip_special_tokens=True)[len(full_prompt):].strip()
32
- chat_history.append(("Assistant", reply))
33
 
34
- formatted_history = []
35
- for i in range(0, len(chat_history), 2):
36
- user_msg = chat_history[i][1] if i < len(chat_history) else ""
37
- bot_msg = chat_history[i+1][1] if i+1 < len(chat_history) else ""
38
- # Render assistant message as markdown
39
- formatted_history.append([user_msg, gr.Markdown(bot_msg)])
40
 
41
- return formatted_history, ""
42
 
43
  with gr.Blocks() as demo:
44
  gr.Markdown("# DeepSeek Coder Chatbot")
45
 
46
- chatbot = gr.Chatbot()
47
  with gr.Row():
48
- user_input = gr.Textbox(show_label=False, placeholder="Enter your prompt here and press Enter")
49
  with gr.Row():
50
  max_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max Tokens")
51
  temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
52
  top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
53
-
54
  def user_submit(text, history, max_tokens, temperature, top_p):
55
  if not text.strip():
56
  return history, ""
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
4
 
5
  model_name = "deepseek-ai/deepseek-coder-1.3b-base"
6
 
 
8
  model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
9
  model.eval()
10
 
11
def respond(message, history, max_tokens, temperature, top_p):
    """Generate an assistant reply for *message* and append it to *history*.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[dict] | None
        Prior conversation as ``{"role": ..., "content": ...}`` dicts —
        the ``gr.Chatbot(type="messages")`` format. ``None`` means a
        fresh conversation.
    max_tokens : int
        Cap on newly generated tokens (``max_new_tokens``).
    temperature, top_p : float
        Sampling controls forwarded to ``model.generate``.

    Returns
    -------
    tuple[list[dict], str]
        The updated history and an empty string (clears the input box).
    """
    history = history or []
    history.append({"role": "user", "content": message})

    # Flatten the conversation into a plain-text prompt:
    # "User: ...\nAssistant: ...\n" ... ending with "Assistant: ".
    prompt = ""
    for msg in history:
        prompt += f"{msg['role'].capitalize()}: {msg['content']}\n"
    prompt += "Assistant: "

    inputs = tokenizer(prompt, return_tensors="pt")
    # Inference only — no_grad avoids building an autograd graph,
    # saving memory during generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: the original sliced the decoded STRING with [len(prompt):],
    # which is unreliable — decoding does not always reproduce the prompt
    # byte-for-byte (special-token / whitespace normalization can shift
    # character offsets). Slice the token ids past the prompt length and
    # decode only the newly generated tokens instead.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    history.append({"role": "assistant", "content": reply})

    return history, ""
39
  with gr.Blocks() as demo:
40
  gr.Markdown("# DeepSeek Coder Chatbot")
41
 
42
+ chatbot = gr.Chatbot(type="messages")
43
  with gr.Row():
44
+ user_input = gr.Textbox(show_label=False, placeholder="Enter your prompt and press Enter")
45
  with gr.Row():
46
  max_tokens = gr.Slider(1, 1024, value=512, step=1, label="Max Tokens")
47
  temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.05, label="Temperature")
48
  top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
49
+
50
  def user_submit(text, history, max_tokens, temperature, top_p):
51
  if not text.strip():
52
  return history, ""