Vladislav Krasnov commited on
Commit
fd89608
·
1 Parent(s): 6f916b8

Update space 4

Browse files
Files changed (1) hide show
  1. app.py +59 -29
app.py CHANGED
@@ -2,8 +2,12 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
5
  model_name = "microsoft/phi-2"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
7
  model = AutoModelForCausalLM.from_pretrained(
8
  model_name,
9
  torch_dtype=torch.float32,
@@ -12,35 +16,61 @@ model = AutoModelForCausalLM.from_pretrained(
12
  )
13
 
14
  def respond(message, history):
15
- prompt = ""
16
- for human_msg, ai_msg in history:
17
- prompt += f"Human: {human_msg}\nAssistant: {ai_msg}\n"
18
- prompt += f"Human: {message}\nAssistant:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
21
- generate_ids = model.generate(
22
- inputs.input_ids,
23
- max_new_tokens=300,
24
- do_sample=True,
25
- temperature=0.7,
26
- top_p=0.9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  )
28
- output = tokenizer.batch_decode(
29
- generate_ids,
30
- skip_special_tokens=True,
31
- clean_up_tokenization_spaces=False
32
- )[0]
33
-
34
- answer = output.split("Assistant:")[-1].strip()
35
- return answer
36
-
37
- # Используем тип 'messages' для чата, как рекомендовано в предупреждении
38
- demo = gr.ChatInterface(
39
- fn=respond,
40
- title="LiveCoder LLM API",
41
- description="Модель Phi-2 для помощи в написании кода. Задавайте вопросы!",
42
- type="messages" # Устанавливаем тип в 'messages'
43
- )
44
 
45
- if __name__ == "__main__":
46
- demo.launch(share=False)
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ # Загрузка модели (остается такой же)
6
  model_name = "microsoft/phi-2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
8
+ if tokenizer.pad_token is None:
9
+ tokenizer.pad_token = tokenizer.eos_token
10
+
11
  model = AutoModelForCausalLM.from_pretrained(
12
  model_name,
13
  torch_dtype=torch.float32,
 
16
  )
17
 
18
  def respond(message, history):
19
+ """Адаптированная функция для Blocks"""
20
+ history = history or []
21
+
22
+ # Формируем промпт
23
+ prompt = "Ты - ассистент для помощи в программировании. Отвечай кратко и по делу.\n\n"
24
+ for human, assistant in history:
25
+ prompt += f"Человек: {human}\nАссистент: {assistant}\n"
26
+ prompt += f"Человек: {message}\nАссистент:"
27
+
28
+ # Генерация
29
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
30
+
31
+ with torch.no_grad():
32
+ outputs = model.generate(
33
+ inputs.input_ids,
34
+ max_new_tokens=300,
35
+ temperature=0.7,
36
+ do_sample=True,
37
+ top_p=0.9
38
+ )
39
+
40
+ response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
41
+
42
+ # Обновляем историю
43
+ history.append((message, response))
44
+ return history, history, "" # Возвращаем историю и очищаем поле ввода
45
 
46
+ # СОЗДАЕМ ИНТЕРФЕЙС ВРУЧНУЮ
47
+ with gr.Blocks(title="LiveCoder LLM API", theme=gr.themes.Soft()) as demo:
48
+ gr.Markdown("# 🚀 LiveCoder LLM API")
49
+ gr.Markdown("Модель Phi-2 для помощи в написании кода")
50
+
51
+ chatbot = gr.Chatbot(height=400, label="Диалог")
52
+ msg = gr.Textbox(label="Ваш вопрос", placeholder="Введите вопрос по программированию...")
53
+ clear = gr.Button("Очистить чат")
54
+
55
+ # Состояние (история диалога)
56
+ state = gr.State([])
57
+
58
+ # Обработчики
59
+ def user_message(message, history):
60
+ return "", history + [[message, None]]
61
+
62
+ def bot_message(history):
63
+ message = history[-1][0]
64
+ # Вызываем функцию respond
65
+ new_history, _, _ = respond(message, history[:-1])
66
+ history[-1][1] = new_history[-1][1]
67
+ return history
68
+
69
+ # Привязка событий
70
+ msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
71
+ bot_message, chatbot, chatbot
72
  )
73
+
74
+ clear.click(lambda: None, None, chatbot, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)