samsl commited on
Commit
7b8e959
·
1 Parent(s): 251e30d

testing oauth

Browse files
Files changed (1) hide show
  1. app.py +33 -21
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import itertools
2
  import json
 
3
  import tempfile
4
 
5
  import biotite.structure as bs
@@ -15,8 +16,8 @@ from biotite.database import rcsb
15
  from biotite.sequence import io as seqio
16
  from biotite.structure import filter_amino_acids, io, spread_residue_wise, to_sequence
17
  from gradio_molecule3d import Molecule3D
18
- from huggingface_hub import get_hf_file_metadata
19
- from huggingface_hub.utils import GatedRepoError
20
  from matplotlib.cm import ScalarMappable
21
  from matplotlib.colors import Normalize
22
 
@@ -117,7 +118,6 @@ def predict_rocketshp(
117
  structure_code: str | None,
118
  structure_file: str | None,
119
  chain_id: str | None,
120
- oauth_token: gr.OAuthToken | None = None,
121
  ):
122
  print(f"sequence text: {sequence}")
123
  print(f"sequence file: {sequence_file}")
@@ -125,12 +125,13 @@ def predict_rocketshp(
125
  print(f"structure file: {structure_file}")
126
  print(f"model variant: {model_variant}")
127
 
128
- if oauth_token is None:
129
- raise gr.Error("Please log in to use this Space")
130
- token_value = oauth_token.token
131
- check_permissions(token_value)
132
 
 
133
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
134
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
135
 
136
  if is_sequence_model:
@@ -145,7 +146,7 @@ def predict_rocketshp(
145
  raise gr.Error("Sequence input is required for the selected model.")
146
 
147
  struct_features = None
148
- seq_features = load_sequence(sequence, device=device)
149
 
150
  else:
151
  if structure_file is None:
@@ -206,9 +207,6 @@ def predict_rocketshp(
206
  sequence = str(to_sequence(structure)[0][0])
207
  seq_features = load_sequence(sequence, device=device)
208
 
209
- # Load the model
210
- model = RocketSHP.load_from_checkpoint(model_variant).to(device)
211
-
212
  # Make predictions
213
  with torch.no_grad():
214
  try:
@@ -410,22 +408,36 @@ def visualize_network(
410
  return fig, bc_highlight, comm_highlight, out_cluster_file_name
411
 
412
 
413
- def check_permissions(token: str | None = None) -> None:
414
- if token is None:
415
- raise gr.Error("Please log in to use this Space")
 
 
416
  try:
417
- url = huggingface_hub.hf_hub_url(
 
 
 
 
418
  repo_id="EvolutionaryScale/esm3-sm-open-v1",
419
- repo_type="model",
420
- filename="config.json",
421
  )
422
- get_hf_file_metadata(url=url)
423
- return
 
424
  except GatedRepoError:
425
- raise gr.Error(
426
- "You must have access to ... to run this Space. Please go through the gating process and come back."
 
427
  )
428
 
 
 
 
 
 
 
429
 
430
  reps = [
431
  {
 
1
  import itertools
2
  import json
3
+ import os
4
  import tempfile
5
 
6
  import biotite.structure as bs
 
16
  from biotite.sequence import io as seqio
17
  from biotite.structure import filter_amino_acids, io, spread_residue_wise, to_sequence
18
  from gradio_molecule3d import Molecule3D
19
+ from huggingface_hub import HfApi
20
+ from huggingface_hub.utils import RepositoryNotFoundError, GatedRepoError
21
  from matplotlib.cm import ScalarMappable
22
  from matplotlib.colors import Normalize
23
 
 
118
  structure_code: str | None,
119
  structure_file: str | None,
120
  chain_id: str | None,
 
121
  ):
122
  print(f"sequence text: {sequence}")
123
  print(f"sequence file: {sequence_file}")
 
125
  print(f"structure file: {structure_file}")
126
  print(f"model variant: {model_variant}")
127
 
128
+ is_authorized, token = check_user_access()
129
+ if not is_authorized:
130
+ raise gr.Error("Failed to authorize repository access.")
 
131
 
132
+ # Load the model
133
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
134
+ model = RocketSHP.load_from_checkpoint(model_variant).to(device)
135
  is_sequence_model = "seq" in model_variant or "mini" in model_variant
136
 
137
  if is_sequence_model:
 
146
  raise gr.Error("Sequence input is required for the selected model.")
147
 
148
  struct_features = None
149
+ seq_features = load_sequence(sequence, device=device, HF_TOKEN=token)
150
 
151
  else:
152
  if structure_file is None:
 
207
  sequence = str(to_sequence(structure)[0][0])
208
  seq_features = load_sequence(sequence, device=device)
209
 
 
 
 
210
  # Make predictions
211
  with torch.no_grad():
212
  try:
 
408
  return fig, bc_highlight, comm_highlight, out_cluster_file_name
409
 
410
 
411
+ def check_user_access(profile: gr.OAuthProfile | None):
412
+ """Check if user is logged in and has access to private repo"""
413
+ if profile is None:
414
+ return False, "Please log in to use this Space"
415
+
416
  try:
417
+ # Try to access a file from your private repo
418
+ api = HfApi(token=profile.oauth_info.access_token)
419
+
420
+ # Test access by trying to get repo info
421
+ api.repo_info(
422
  repo_id="EvolutionaryScale/esm3-sm-open-v1",
423
+ repo_type="model", # or "dataset" or "space"
424
+ token=profile.oauth_info.access_token,
425
  )
426
+
427
+ return True, profile.oauth_info.access_token
428
+
429
  except GatedRepoError:
430
+ return (
431
+ False,
432
+ "You need to request access to the private repository at https://huggingface.co/username/private-repo-name",
433
  )
434
 
435
+ except RepositoryNotFoundError:
436
+ return False, "You don't have access to the required repository"
437
+
438
+ except Exception as e:
439
+ return False, f"Error checking access: {str(e)}"
440
+
441
 
442
  reps = [
443
  {