MindVR committed on
Commit
fad09d8
·
verified ·
1 Parent(s): 7e877bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -3
app.py CHANGED
@@ -4,6 +4,26 @@ from typing import List
4
  from huggingface_hub import login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN")
9
  if HF_TOKEN:
@@ -17,13 +37,17 @@ model = AutoModelForCausalLM.from_pretrained(
17
  low_cpu_mem_usage=True,
18
  token=HF_TOKEN
19
  )
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
  model.to(device)
22
 
 
23
def build_prompt(prompt: str, histories: List[str], new_message: str) -> str:
    """Assemble the conversation prompt for the language model.

    Joins the optional system prompt, any prior turns in *histories*,
    and the new user message into the ``User: ...\\nAI:`` format the
    model is expected to complete.
    """
    parts = []
    if prompt:
        parts.append(prompt.strip())
    parts.extend(histories)
    parts.append(f"User: {new_message}\nAI:")
    return "\n".join(parts)
29
 
 
4
  from huggingface_hub import login
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
  import gradio as gr
7
from transformers import AutoTokenizer as SummarizerTokenizer, AutoModelForSeq2SeqLM

# NOTE(review): assumes `torch` and `os` are imported earlier in the file
# (before the visible hunk) — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Summarization model: BART fine-tuned on CNN/DailyMail, used to compress
# long chat histories before they are spliced into the causal-LM prompt.
summarizer_model_id = "facebook/bart-large-cnn"
summarizer_tokenizer = SummarizerTokenizer.from_pretrained(summarizer_model_id)
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained(summarizer_model_id)
summarizer_model.to(device)
17
def summarize_text(text: str, max_length: int = 150) -> str:
    """Summarize *text* with the module-level BART summarizer.

    Args:
        text: Input text; the tokenizer truncates it to 1024 tokens.
        max_length: Maximum length, in tokens, of the generated summary.

    Returns:
        The decoded summary string, special tokens stripped.
    """
    inputs = summarizer_tokenizer(
        [text], return_tensors="pt", max_length=1024, truncation=True
    ).to(device)
    # Inference only — disable autograd so generation doesn't track
    # gradients (saves memory and time).
    with torch.no_grad():
        summary_ids = summarizer_model.generate(
            inputs["input_ids"],
            num_beams=4,
            max_length=max_length,
            early_stopping=True,
        )
    return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
27
 
28
  HF_TOKEN = os.environ.get("HF_TOKEN")
29
  if HF_TOKEN:
 
37
  low_cpu_mem_usage=True,
38
  token=HF_TOKEN
39
  )
 
40
  model.to(device)
41
 
42
# Single build_prompt kept; long histories are summarized before inclusion.
def build_prompt(
    prompt: str,
    histories: List[str],
    new_message: str,
    summarize_threshold: int = 3000,
) -> str:
    """Assemble the conversation prompt, compressing long histories.

    Args:
        prompt: Optional system prompt; stripped and placed first.
        histories: Prior conversation turns, joined with newlines.
        new_message: The latest user message.
        summarize_threshold: Character count above which the joined
            history is summarized via ``summarize_text`` (default 3000).

    Returns:
        The full prompt string ending in ``User: ...\\nAI:`` for the
        model to complete.
    """
    prompt_text = prompt.strip() + "\n" if prompt else ""
    histories_text = "\n".join(histories) if histories else ""
    # Summarize only when the history exceeds the (tunable) threshold.
    if len(histories_text) > summarize_threshold:
        histories_text = summarize_text(histories_text, max_length=180)
    if histories_text:
        prompt_text += histories_text + "\n"
    prompt_text += f"User: {new_message}\nAI:"
    return prompt_text
53