habdine committed on
Commit
fb8d1cc
·
verified ·
1 Parent(s): 2955114

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -68
app.py CHANGED
@@ -1,10 +1,11 @@
1
- from typing import List, Tuple
 
2
 
3
  import gradio as gr
4
  from gradio.themes import Soft
5
  import spaces
6
  import torch
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
 
9
  TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
10
  PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
@@ -16,7 +17,7 @@ PROTEIN_HERO = f"""
16
  """
17
 
18
  DESCRIPTION = f"""\
19
- ### Prot2Text-V2 Demo (Spaces Edition)
20
 
21
  {PROTEIN_HERO}
22
 
@@ -44,17 +45,22 @@ EXAMPLE_SEQUENCES = [
44
  [
45
  "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
46
  ],
 
 
 
 
 
 
 
 
 
47
  ]
48
 
49
  MAX_MAX_NEW_TOKENS = 256
50
  DEFAULT_MAX_NEW_TOKENS = 100
51
 
52
 
53
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- if device.type == "cuda":
55
- dtype = torch.bfloat16 if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() else torch.float16
56
- else:
57
- dtype = torch.float32
58
 
59
  system_message = (
60
  "You are a scientific assistant specialized in protein function "
@@ -67,20 +73,17 @@ placeholder = '<|reserved_special_token_1|>'
67
 
68
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
69
  llama_tokenizer = AutoTokenizer.from_pretrained(
70
- pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
71
- pad_token='<|reserved_special_token_0|>',
72
- )
73
- model = AutoModelForCausalLM.from_pretrained(
74
- "xiao-fei/Prot2Text-V2-11B-Instruct-hf",
75
- trust_remote_code=True,
76
- torch_dtype=dtype,
77
  )
78
- model = model.to(device)
 
 
79
  model.eval()
80
 
81
 
82
  @spaces.GPU(duration=90)
83
- def generate_response(
84
  message: str,
85
  max_new_tokens: int = 1024,
86
  do_sample: bool = False,
@@ -88,42 +91,52 @@ def generate_response(
88
  top_p: float = 0.9,
89
  top_k: int = 50,
90
  repetition_penalty: float = 1.2,
91
- ) -> str:
92
- user_message = "Sequence embeddings: " + placeholder * (len(message) + 2)
 
 
 
 
93
  tokenized_prompt = llama_tokenizer.apply_chat_template(
94
  [
95
  {"role": "system", "content": system_message},
96
- {"role": "user", "content": user_message},
97
- ],
98
- add_generation_prompt=True,
99
- tokenize=True,
100
- return_tensors="pt",
101
- return_dict=True,
102
  )
103
- tokenized_sequence = esm_tokenizer(message, return_tensors="pt")
104
-
105
- with torch.inference_mode():
106
- generated = model.generate(
107
- inputs=tokenized_prompt["input_ids"].to(device),
108
- attention_mask=tokenized_prompt["attention_mask"].to(device),
109
- protein_input_ids=tokenized_sequence["input_ids"].to(device),
110
- protein_attention_mask=tokenized_sequence["attention_mask"].to(device),
111
- eos_token_id=128009,
112
- pad_token_id=128002,
113
- return_dict_in_generate=False,
114
- num_beams=1,
115
- max_new_tokens=max_new_tokens,
116
- do_sample=do_sample,
117
- top_p=top_p,
118
- top_k=top_k,
119
- temperature=temperature,
120
- repetition_penalty=repetition_penalty,
121
- )
 
 
 
 
 
 
122
 
123
- prompt_len = tokenized_prompt["input_ids"].shape[-1]
124
- response_tokens = generated[0, prompt_len:]
125
- response_text = llama_tokenizer.decode(response_tokens, skip_special_tokens=True)
126
- return response_text.strip()
127
 
128
 
129
  ChatHistory = List[Tuple[str, str]]
@@ -147,30 +160,29 @@ def handle_submit(
147
  conversation = history.copy()
148
  conversation.append((message, ""))
149
 
150
- try:
151
- response_text = generate_response(
152
- message=message,
153
- max_new_tokens=max_new_tokens,
154
- do_sample=do_sample,
155
- temperature=temperature,
156
- top_p=top_p,
157
- top_k=top_k,
158
- repetition_penalty=repetition_penalty,
159
- )
160
- except Exception as exc:
161
- response_text = f"⚠️ Generation failed: {exc}"
162
-
163
- conversation[-1] = (message, response_text)
164
- snapshot = conversation.copy()
165
- return snapshot, snapshot, gr.update(value="")
166
 
167
 
168
  def clear_conversation():
169
  empty_history: ChatHistory = []
170
  return empty_history, empty_history, gr.update(value="")
171
 
172
-
173
- theme = Soft(primary_hue="slate", secondary_hue="stone", neutral_hue="gray")
 
 
 
174
 
175
  with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
176
  with gr.Row(equal_height=True):
@@ -187,6 +199,8 @@ with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
187
  gr.Markdown(DESCRIPTION)
188
  with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
189
  history_state = gr.State([])
 
 
190
  chatbot = gr.Chatbot(
191
  label="Generated Function",
192
  height=350,
@@ -289,10 +303,9 @@ with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
289
  )
290
  with gr.Accordion("Model & usage notes", open=False):
291
  gr.Markdown(
292
- "- **Model stack**: Facebook ESM2 encoder + Prot2Text-V2 11B instruction-tuned decoder.\n"
293
  "- **Token budget**: the generator truncates after the configured `Max new tokens`.\n"
294
  "- **Attribution**: Outputs are predictions; validate experimentally before publication.\n"
295
- "- **Privacy**: Sequence inputs stay within this session - export or clear as needed."
296
  )
297
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
298
 
 
1
+ from threading import Thread
2
+ from typing import Iterator, List, Tuple
3
 
4
  import gradio as gr
5
  from gradio.themes import Soft
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  TEAM_LOGO_URL = "http://nlp.polytechnique.fr/static/images/logo_dascim.png"
11
  PROTEIN_VISUAL_URL = "https://cas-bridge.xethub.hf.co/xet-bridge-us/68e677c594d3f20bbeecf13c/7cff6ae021d7c518ee4e2fcb70490516ad9e4999ec75c6a5dd164cc6ca64ae30?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20251023%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20251023T094659Z&X-Amz-Expires=3600&X-Amz-Signature=6a7598d77a46df971e88e1f378bc5e06794a3893f31319a6ab3431e4323d755c&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=66448b4fecac3bc79b26304f&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.png%3B+filename%3D%22model.png%22%3B&response-content-type=image%2Fpng&x-id=GetObject&Expires=1761216419&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2MTIxNjQxOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2FzLWJyaWRnZS54ZXRodWIuaGYuY28veGV0LWJyaWRnZS11cy82OGU2NzdjNTk0ZDNmMjBiYmVlY2YxM2MvN2NmZjZhZTAyMWQ3YzUxOGVlNGUyZmNiNzA0OTA1MTZhZDllNDk5OWVjNzVjNmE1ZGQxNjRjYzZjYTY0YWUzMCoifV19&Signature=YjrX1ZF%7EX1qw-m2nWOY8AxdSXwbrsidvlTZ5YWXZx3UPv0my0u68lWcpWIpIxzkGeWTtWPvlCfMcmnpmmwS2wHexorhgq9c7%7E3Ghw20evO0EMPvHBwP4vWYmXW8nHBqqqbw8Qy1pojDm9TvXV19O4-fCFxPi1aQ5FOTC2Kmn9gKxW%7EAN7vkWnfhU8QcCf18139hMbUvh9YoJ%7EesOWXoCFWgAbyz%7Eroajt5e3oM9b-IsU%7E2-UzMZ4%7EMA2MSOFmg487bhZDbr2IMD15-8O0jzWu3qyO3T1H06S-9kTdI%7EC6AYtXUY8YtSWKw%7EBzhARjXK6%7EuZ3c3kE1V7%7EdnLl1YM-2w__&Key-Pair-Id=K2L8F4GPSG1IFC"
 
17
  """
18
 
19
  DESCRIPTION = f"""\
20
+ ### Prot2Text-V2 Demo
21
 
22
  {PROTEIN_HERO}
23
 
 
45
  [
46
  "MAVVLPAVVEELLSEMAAAVQESARIPDEYLLSLKFLFGSSATQALDLVDRQSITLISSPSGRRVYQVLGSSSKTYTCLASCHYCSCPAFAFSVLRKSDSILCKHLLAVYLSQVMRTCQQLSVSDKQLTDILLMEKKQEA"
47
  ],
48
+ [
49
+ "MCYSANGNTFLIVDNTQKRIPEEKKPDFVRENVGDLDGVIFVELVDGKYFMDYYNRDGSMAAFCGNGARAFSQYLIDRGWIKEKEFTFLSRAGEIKVIVDDSIWVRMPGVSEKKEMKVDGYEGYFVVVGVPHFVMEVKGIDELDVEKLGRDLRYKTGANVDFYEVLPDRLKVRTYERGVERETKACGTGVTSVFVVYRDKTGAKEVKIQVPGGTLFLKEENGEIFLRGDVKRCSEE"
50
+ ],
51
+ [
52
+ "MTQEERFEQRIAQETAIEPQDWMPDAYRKTLIRQIGQHAHSEIVGMLPEGNWITRAPTLRRKAILLAKVQDEAGHGLYLYSAAETLGCAREDIYQKMLDGRMKYSSIFNYPTLSWADIGVIGWLVDGAAIVNQVALCRTSYGPYARAMVKICKEESFHQRQGFEACMALAQGSEAQKQMLQDAINRFWWPALMMFGPNDDNSPNSARSLTWKIKRFTNDELRQRFVDNTVPQVEMLGMTVPDPDLHFDTESGHYRFGEIDWQEFNEVINGRGICNQERLDAKRKAWEEGTWVREAALAHAQKQHARKVA"
53
+ ],
54
+ [
55
+ "MTTRMIILNGGSSAGKSGIVRCLQSVLPEPWLAFGVDSLIEAMPLKMQSAEGGIEFDADGGVSIGPEFRALEGAWAEGVVAMARAGARIIIDDVFLGGAAAQERWRSFVGDLDVLWVGVRCDGAVAEGRETARGDRVAGMAAKQAYVVHEGVEYDVEVDTTHKESIECAWAIAAHVVP"
56
+ ],
57
  ]
58
 
59
  MAX_MAX_NEW_TOKENS = 256
60
  DEFAULT_MAX_NEW_TOKENS = 100
61
 
62
 
63
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
64
 
65
  system_message = (
66
  "You are a scientific assistant specialized in protein function "
 
73
 
74
  esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
75
  llama_tokenizer = AutoTokenizer.from_pretrained(
76
+ pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct",
77
+ pad_token='<|reserved_special_token_0|>'
 
 
 
 
 
78
  )
79
+ model = AutoModelForCausalLM.from_pretrained('xiao-fei/Prot2Text-V2-11B-Instruct-hf',
80
+ trust_remote_code=True,
81
+ torch_dtype=torch.bfloat16,).to(device)
82
  model.eval()
83
 
84
 
85
  @spaces.GPU(duration=90)
86
+ def stream_response(
87
  message: str,
88
  max_new_tokens: int = 1024,
89
  do_sample: bool = False,
 
91
  top_p: float = 0.9,
92
  top_k: int = 50,
93
  repetition_penalty: float = 1.2,
94
+ ) -> Iterator[str]:
95
+
96
+
97
+ streamer = TextIteratorStreamer(llama_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
98
+
99
+ user_message = "Sequence embeddings: " + placeholder * (len(message)+2)
100
  tokenized_prompt = llama_tokenizer.apply_chat_template(
101
  [
102
  {"role": "system", "content": system_message},
103
+ {"role": "user", "content": user_message}
104
+ ],
105
+ add_generation_prompt=True,
106
+ tokenize=True,
107
+ return_tensors="pt",
108
+ return_dict=True
109
  )
110
+ tokenized_sequence = esm_tokenizer(
111
+ message,
112
+ return_tensors="pt"
113
+ )
114
+ model.eval()
115
+ generate_kwargs = dict(
116
+ inputs=tokenized_prompt["input_ids"].to(model.device),
117
+ attention_mask=tokenized_prompt["attention_mask"].to(model.device),
118
+ protein_input_ids=tokenized_sequence["input_ids"].to(model.device),
119
+ protein_attention_mask=tokenized_sequence["attention_mask"].to(model.device),
120
+ eos_token_id=128009,
121
+ pad_token_id=128002,
122
+ return_dict_in_generate=False,
123
+ num_beams=1,
124
+ # device=device,
125
+ streamer=streamer,
126
+ max_new_tokens=max_new_tokens,
127
+ do_sample=do_sample,
128
+ top_p=top_p,
129
+ top_k=top_k,
130
+ temperature=temperature,
131
+ repetition_penalty=repetition_penalty,
132
+ )
133
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
134
+ t.start()
135
 
136
+ outputs = []
137
+ for text in streamer:
138
+ outputs.append(text)
139
+ yield "".join(outputs)
140
 
141
 
142
  ChatHistory = List[Tuple[str, str]]
 
160
  conversation = history.copy()
161
  conversation.append((message, ""))
162
 
163
+ for partial_response in stream_response(
164
+ message=message,
165
+ max_new_tokens=max_new_tokens,
166
+ do_sample=do_sample,
167
+ temperature=temperature,
168
+ top_p=top_p,
169
+ top_k=top_k,
170
+ repetition_penalty=repetition_penalty,
171
+ ):
172
+ conversation[-1] = (message, partial_response)
173
+ snapshot = conversation.copy()
174
+ yield snapshot, snapshot, gr.update(value="")
 
 
 
 
175
 
176
 
177
  def clear_conversation():
178
  empty_history: ChatHistory = []
179
  return empty_history, empty_history, gr.update(value="")
180
 
181
+ theme = Soft(
182
+ primary_hue="slate",
183
+ secondary_hue="stone",
184
+ neutral_hue="gray",
185
+ )
186
 
187
  with gr.Blocks(theme=theme, css_paths="style.css", fill_height=True) as demo:
188
  with gr.Row(equal_height=True):
 
199
  gr.Markdown(DESCRIPTION)
200
  with gr.Column(scale=7, min_width=400, elem_classes="interaction-column"):
201
  history_state = gr.State([])
202
+
203
+
204
  chatbot = gr.Chatbot(
205
  label="Generated Function",
206
  height=350,
 
303
  )
304
  with gr.Accordion("Model & usage notes", open=False):
305
  gr.Markdown(
306
+ "- **Model stack**: Facebook ESM2 encoder + Llama 3.1 8B instruction-tuned decoder.\n"
307
  "- **Token budget**: the generator truncates after the configured `Max new tokens`.\n"
308
  "- **Attribution**: Outputs are predictions; validate experimentally before publication.\n"
 
309
  )
310
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
311