VirtualInsight committed on
Commit
28bae52
·
verified ·
1 Parent(s): 92854fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -39,15 +39,15 @@ print(f"EOS token ID: {EOS_TOKEN_ID}")
39
  @torch.no_grad()
40
  def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
41
  """
42
- Generates a chat-style response using the Lumen-Instruct model.
43
  """
44
- # Format the input as a structured conversation
45
  formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
46
 
47
  # Tokenize input
48
  input_ids = torch.tensor([tokenizer.encode(formatted_prompt).ids], dtype=torch.long, device=device)
49
 
50
- # Generate response with sampling
51
  output = generate(
52
  model,
53
  input_ids,
@@ -59,17 +59,20 @@ def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
59
  eos_token_id=EOS_TOKEN_ID,
60
  )
61
 
62
- # Decode full output text
63
  full_text = tokenizer.decode(output[0].tolist())
64
 
65
- # Extract only assistant’s part
66
  if "<|im_start|>assistant" in full_text:
67
  response = full_text.split("<|im_start|>assistant")[-1]
68
- if "<|im_end|>" in response:
69
- response = response.split("<|im_end|>")[0]
70
- return response.strip()
71
 
72
- return full_text.strip()
 
 
 
73
 
74
  # -----------------------------
75
  # Gradio Interface
@@ -88,7 +91,7 @@ demo = gr.Interface(
88
  )
89
 
90
  # -----------------------------
91
- # Launch
92
  # -----------------------------
93
  if __name__ == "__main__":
94
- demo.launch()
 
39
  @torch.no_grad()
40
  def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
41
  """
42
+ Generates a clean assistant-only response from the Lumen Instruct model.
43
  """
44
+ # Format input as a conversation prompt
45
  formatted_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
46
 
47
  # Tokenize input
48
  input_ids = torch.tensor([tokenizer.encode(formatted_prompt).ids], dtype=torch.long, device=device)
49
 
50
+ # Generate output
51
  output = generate(
52
  model,
53
  input_ids,
 
59
  eos_token_id=EOS_TOKEN_ID,
60
  )
61
 
62
+ # Decode full text
63
  full_text = tokenizer.decode(output[0].tolist())
64
 
65
+ # 🧹 Clean extraction of assistant’s reply only
66
  if "<|im_start|>assistant" in full_text:
67
  response = full_text.split("<|im_start|>assistant")[-1]
68
+ response = response.split("<|im_end|>")[0] if "<|im_end|>" in response else response
69
+ else:
70
+ response = full_text
71
 
72
+ # Remove potential leftover role tokens and clean spaces
73
+ response = response.replace("assistant", "").replace("user", "").strip()
74
+
75
+ return response
76
 
77
  # -----------------------------
78
  # Gradio Interface
 
91
  )
92
 
93
  # -----------------------------
94
+ # Launch Interface
95
  # -----------------------------
96
  if __name__ == "__main__":
97
+ demo.launch(share=True)