habdine committed on
Commit
bfc54a9
·
verified ·
1 Parent(s): ea465a5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +68 -81
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,11 +1,10 @@
1
- from threading import Thread
2
- from typing import Iterator, List, Tuple
3
 
4
  import gradio as gr
5
  from gradio.themes import Soft
6
  import spaces
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
11
  PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
@@ -17,7 +16,7 @@ PROTEIN_HERO = f"""
17
  """
18
 
19
  DESCRIPTION = f"""\
20
- ### Prot2Text-V2 Demo
21
 
22
  {PROTEIN_HERO}
23
 
@@ -45,22 +44,17 @@ EXAMPLE_SEQUENCES = [
45
  [
46
  "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
47
  ],
48
- [
49
- "MCYSANGNTFLIVDNTQKRIPEEKKPDFVRENVGDLDGVIFVELVDGKYFMDYYNRDGSMAAFCGNGARAFSQYLIDRGWIKEKEFTFLSRAGEIKVIVDDSIWVRMPGVSEKKEMKVDGYEGYFVVVGVPHFVMEVKGIDELDVEKLGRDLRYKTGANVDFYEVLPDRLKVRTYERGVERETKACGTGVTSVFVVYRDKTGAKEVKIQVPGGTLFLKEENGEIFLRGDVKRCSEE"
50
- ],
51
- [
52
- "MTQEERFEQRIAQETAIEPQDWMPDAYRKTLIRQIGQHAHSEIVGMLPEGNWITRAPTLRRKAILLAKVQDEAGHGLYLYSAAETLGCAREDIYQKMLDGRMKYSSIFNYPTLSWADIGVIGWLVDGAAIVNQVALCRTSYGPYARAMVKICKEESFHQRQGFEACMALAQGSEAQKQMLQDAINRFWWPALMMFGPNDDNSPNSARSLTWKIKRFTNDELRQRFVDNTVPQVEMLGMTVPDPDLHFDTESGHYRFGEIDWQEFNEVINGRGICNQERLDAKRKAWEEGTWVREAALAHAQKQHARKVA"
53
- ],
54
- [
55
- "MTTRMIILNGGSSAGKSGIVRCLQSVLPEPWLAFGVDSLIEAMPLKMQSAEGGIEFDADGGVSIGPEFRALEGAWAEGVVAMARAGARIIIDDVFLGGAAAQERWRSFVGDLDVLWVGVRCDGAVAEGRETARGDRVAGMAAKQAYVVHEGVEYDVEVDTTHKESIECAWAIAAHVVP"
56
- ],
57
  ]
58
 
59
  MAX_MAX_NEW_TOKENS = 256
60
  DEFAULT_MAX_NEW_TOKENS = 100
61
 
62
 
63
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
64
 
65
  system_message = (
66
  "You are a scientific assistant specialized in protein function "
@@ -73,17 +67,20 @@ placeholder = '<|reserved_special_token_1|>'
73
 
74
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
75
  llama_tokenizer = AutoTokenizer.from_pretrained(
76
- pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
77
- pad_token='<|reserved_special_token_0|>'
 
 
 
 
 
78
  )
79
- model = AutoModelForCausalLM.from_pretrained('xiao-fei/Prot2Text-V2-11B-Instruct-hf',
80
- trust_remote_code=True,
81
- torch_dtype=torch.bfloat16,).to(device)
82
  model.eval()
83
 
84
 
85
  @spaces.GPU(duration=90)
86
- def stream_response(
87
  message: str,
88
  max_new_tokens: int = 1024,
89
  do_sample: bool = False,
@@ -91,52 +88,42 @@ def stream_response(
91
  top_p: float = 0.9,
92
  top_k: int = 50,
93
  repetition_penalty: float = 1.2,
94
- ) -> Iterator[str]:
95
-
96
-
97
- streamer = TextIteratorStreamer(llama_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
98
-
99
- user_message = "Sequence embeddings: " + placeholder * (len(message)+2)
100
  tokenized_prompt = llama_tokenizer.apply_chat_template(
101
  [
102
  {"role": "system", "content": system_message},
103
- {"role": "user", "content": user_message}
104
- ],
105
- add_generation_prompt=True,
106
- tokenize=True,
107
- return_tensors="pt",
108
- return_dict=True
109
- )
110
- tokenized_sequence = esm_tokenizer(
111
- message,
112
- return_tensors="pt"
113
- )
114
- model.eval()
115
- generate_kwargs = dict(
116
- inputs=tokenized_prompt["input_ids"].to(model.device),
117
- attention_mask=tokenized_prompt["attention_mask"].to(model.device),
118
- protein_input_ids=tokenized_sequence["input_ids"].to(model.device),
119
- protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device),
120
- eos_token_id=128009,
121
- pad_token_id=128002,
122
- return_dict_in_generate=False,
123
- num_beams=1,
124
- # device=device,
125
- streamer=streamer,
126
- max_new_tokens=max_new_tokens,
127
- do_sample=do_sample,
128
- top_p=top_p,
129
- top_k=top_k,
130
- temperature=temperature,
131
- repetition_penalty=repetition_penalty,
132
  )
133
- t = Thread(target=model.generate, kwargs=generate_kwargs)
134
- t.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- outputs = []
137
- for text in streamer:
138
- outputs.append(text)
139
- yield "".join(outputs)
140
 
141
 
142
  ChatHistory = List[Tuple[str, str]]
@@ -160,29 +147,30 @@ def handle_submit(
160
  conversation = history.copy()
161
  conversation.append((message, ""))
162
 
163
- for partial_response in stream_response(
164
- message=message,
165
- max_new_tokens=max_new_tokens,
166
- do_sample=do_sample,
167
- temperature=temperature,
168
- top_p=top_p,
169
- top_k=top_k,
170
- repetition_penalty=repetition_penalty,
171
- ):
172
- conversation[-1] = (message, partial_response)
173
- snapshot = conversation.copy()
174
- yield snapshot, snapshot, gr.update(value="")
 
 
 
 
175
 
176
 
177
  def clear_conversation():
178
  empty_history: ChatHistory = []
179
  return empty_history, empty_history, gr.update(value="")
180
 
181
- theme = Soft(
182
- primary_hue="slate",
183
- secondary_hue="stone",
184
- neutral_hue="gray",
185
- )
186
 
187
  with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
188
  with gr.Row(equal_height=True):
@@ -199,8 +187,6 @@ with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
199
  gr.Markdown(DESCRIPTION)
200
  with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
201
  history_state = gr.State([])
202
-
203
-
204
  chatbot = gr.Chatbot(
205
  label="Generated Function",
206
  height=350,
@@ -303,9 +289,10 @@ with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
303
  )
304
  with gr.Accordion("Model & usage notes", open=False):
305
  gr.Markdown(
306
- "- **Model stack**: Facebook ESM2 encoder + Llama 3.1 8B instruction-tuned decoder.\n"
307
  "- **Token budget**: the generator truncates after the configured `Max new tokens`.\n"
308
  "- **Attribution**: Outputs are predictions; validate experimentally before publication.\n"
 
309
  )
310
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
311
 
 
1
+ from typing import List, Tuple
 
2
 
3
  import gradio as gr
4
  from gradio.themes import Soft
5
  import spaces
6
  import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
10
  PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
 
16
  """
17
 
18
  DESCRIPTION = f"""\
19
+ ### Prot2Text-V2 Demo (Spaces Edition)
20
 
21
  {PROTEIN_HERO}
22
 
 
44
  [
45
  "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
46
  ],
 
 
 
 
 
 
 
 
 
47
  ]
48
 
49
  MAX_MAX_NEW_TOKENS = 256
50
  DEFAULT_MAX_NEW_TOKENS = 100
51
 
52
 
53
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+ if device.type == "cuda":
55
+ dtype = torch.bfloat16 if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() else torch.float16
56
+ else:
57
+ dtype = torch.float32
58
 
59
  system_message = (
60
  "You are a scientific assistant specialized in protein function "
 
67
 
68
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
69
  llama_tokenizer = AutoTokenizer.from_pretrained(
70
+ pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
71
+ pad_token='<|reserved_special_token_0|>',
72
+ )
73
+ model = AutoModelForCausalLM.from_pretrained(
74
+ "xiao-fei/Prot2Text-V2-11B-Instruct-hf",
75
+ trust_remote_code=True,
76
+ torch_dtype=dtype,
77
  )
78
+ model = model.to(device)
 
 
79
  model.eval()
80
 
81
 
82
  @spaces.GPU(duration=90)
83
+ def generate_response(
84
  message: str,
85
  max_new_tokens: int = 1024,
86
  do_sample: bool = False,
 
88
  top_p: float = 0.9,
89
  top_k: int = 50,
90
  repetition_penalty: float = 1.2,
91
+ ) -> str:
92
+ user_message = "Sequence embeddings: " + placeholder * (len(message) + 2)
 
 
 
 
93
  tokenized_prompt = llama_tokenizer.apply_chat_template(
94
  [
95
  {"role": "system", "content": system_message},
96
+ {"role": "user", "content": user_message},
97
+ ],
98
+ add_generation_prompt=True,
99
+ tokenize=True,
100
+ return_tensors="pt",
101
+ return_dict=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  )
103
+ tokenized_sequence = esm_tokenizer(message, return_tensors="pt")
104
+
105
+ with torch.inference_mode():
106
+ generated = model.generate(
107
+ inputs=tokenized_prompt["input_ids"].to(device),
108
+ attention_mask=tokenized_prompt["attention_mask"].to(device),
109
+ protein_input_ids=tokenized_sequence["input_ids"].to(device),
110
+ protein_attention_mask=tokenized_sequence["attention_mask"].to(device),
111
+ eos_token_id=128009,
112
+ pad_token_id=128002,
113
+ return_dict_in_generate=False,
114
+ num_beams=1,
115
+ max_new_tokens=max_new_tokens,
116
+ do_sample=do_sample,
117
+ top_p=top_p,
118
+ top_k=top_k,
119
+ temperature=temperature,
120
+ repetition_penalty=repetition_penalty,
121
+ )
122
 
123
+ prompt_len = tokenized_prompt["input_ids"].shape[-1]
124
+ response_tokens = generated[0, prompt_len:]
125
+ response_text = llama_tokenizer.decode(response_tokens, skip_special_tokens=True)
126
+ return response_text.strip()
127
 
128
 
129
  ChatHistory = List[Tuple[str, str]]
 
147
  conversation = history.copy()
148
  conversation.append((message, ""))
149
 
150
+ try:
151
+ response_text = generate_response(
152
+ message=message,
153
+ max_new_tokens=max_new_tokens,
154
+ do_sample=do_sample,
155
+ temperature=temperature,
156
+ top_p=top_p,
157
+ top_k=top_k,
158
+ repetition_penalty=repetition_penalty,
159
+ )
160
+ except Exception as exc:
161
+ response_text = f"⚠️ Generation failed: {exc}"
162
+
163
+ conversation[-1] = (message, response_text)
164
+ snapshot = conversation.copy()
165
+ yield snapshot, snapshot, gr.update(value="")
166
 
167
 
168
  def clear_conversation():
169
  empty_history: ChatHistory = []
170
  return empty_history, empty_history, gr.update(value="")
171
 
172
+
173
+ theme = Soft(primary_hue="slate", secondary_hue="stone", neutral_hue="gray")
 
 
 
174
 
175
  with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
176
  with gr.Row(equal_height=True):
 
187
  gr.Markdown(DESCRIPTION)
188
  with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
189
  history_state = gr.State([])
 
 
190
  chatbot = gr.Chatbot(
191
  label="Generated Function",
192
  height=350,
 
289
  )
290
  with gr.Accordion("Model & usage notes", open=False):
291
  gr.Markdown(
292
+ "- **Model stack**: Facebook ESM2 encoder + Prot2Text-V2 11B instruction-tuned decoder.\n"
293
  "- **Token budget**: the generator truncates after the configured `Max new tokens`.\n"
294
  "- **Attribution**: Outputs are predictions; validate experimentally before publication.\n"
295
+ "- **Privacy**: Sequence inputs stay within this session - export or clear as needed."
296
  )
297
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
298
 
requirements.txt CHANGED
@@ -4,5 +4,6 @@ gradio>=5.0.0
4
  huggingface-hub>=0.23.0
5
  mistral-common>=1.4.0
6
  sentencepiece>=0.1.99
 
7
  torch>=2.1.0
8
  transformers>=4.38.0
 
4
  huggingface-hub>=0.23.0
5
  mistral-common>=1.4.0
6
  sentencepiece>=0.1.99
7
+ spaces>=0.27.0
8
  torch>=2.1.0
9
  transformers>=4.38.0