---
license: mit
datasets:
- universeTBD/arxiv-astro-abstracts-all
language:
- en
metrics:
- perplexity
pipeline_tag: text-generation
tags:
- llama-2
- astronomy
- astrophysics
- arxiv
inference: false
---

# AstroLLaMA

**Play with the model in our Hugging Face space!** https://huggingface.co/spaces/universeTBD/astrollama

<p align="center">
    <img src="https://huggingface.co/universeTBD/astrollama/resolve/main/images/astrollama-logo.png" alt="AstroLLaMA" width="500px"/>
</p>

## Loading the model

```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama"
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="universeTBD/astrollama",
    device_map="auto",
)
```

## Generating text from a prompt

```python
import torch
from transformers import pipeline

# Reuses the `model` and `tokenizer` loaded above
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto"
)

# Taken from https://arxiv.org/abs/2308.12823
prompt = "In this letter, we report the discovery of the highest redshift, " \
    "heavily obscured, radio-loud QSO candidate selected using JWST NIRCam/MIRI, " \
    "mid-IR, sub-mm, and radio imaging in the COSMOS-Web field. "

# For reproducibility
torch.manual_seed(42)

# max_length counts the prompt tokens plus the newly generated tokens
generated_text = generator(
    prompt,
    do_sample=True,
    max_length=512
)

# The pipeline returns a list of dicts, one per generated sequence
print(generated_text[0]["generated_text"])
```
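The text-generation pipeline returns a list of dictionaries, and by default each `generated_text` value starts with the prompt itself. If only the continuation is wanted, one option is to pass `return_full_text=False` to the generator call. Another is to strip the prompt afterwards, as in the minimal sketch below. The helper `continuation_only` is a hypothetical name introduced here for illustration, and the mock output is a toy stand-in rather than real model output.

```python
def continuation_only(outputs, prompt):
    """Return each generated string with the leading prompt removed.

    `outputs` is assumed to have the shape returned by the text-generation
    pipeline: a list of dicts with a "generated_text" key.
    """
    texts = []
    for item in outputs:
        text = item["generated_text"]
        # Strip the prompt only when it is an exact prefix of the output
        if text.startswith(prompt):
            text = text[len(prompt):]
        texts.append(text)
    return texts


# Toy stand-in for the pipeline's return value (not real model output)
mock_prompt = "In this letter, we report "
mock_outputs = [{"generated_text": "In this letter, we report a toy continuation."}]
print(continuation_only(mock_outputs, mock_prompt))
```

With the real pipeline, the call would be `continuation_only(generated_text, prompt)`.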

## Embedding text with AstroLLaMA

```python
import torch

texts = [
    "Abstract 1",
    "Abstract 2"
]
inputs = tokenizer(
    texts,
    return_tensors="pt",
    return_token_type_ids=False,
    padding=True,
    truncation=True,
    max_length=4096
)
inputs = inputs.to(model.device)

# Forward pass without gradient tracking; only the hidden states are needed
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# Last layer of the hidden states: skip the first position, then average
# over the remaining tokens to get one embedding vector per text
embeddings = outputs["hidden_states"][-1][:, 1:, ...].mean(1).cpu().numpy()
```
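The snippet above averages over every position after the first, so when the texts in a batch differ in length, padding positions are included in the mean. A mask-aware average is one way to avoid that. The sketch below uses NumPy and assumes the tokenizer pads on the right, so that index 0 is the position the original snippet skips. `masked_mean_pool` is a hypothetical helper, not part of AstroLLaMA or `transformers`, and the arrays are toy data.

```python
import numpy as np


def masked_mean_pool(hidden, attention_mask, skip_first=True):
    """Average token vectors per sequence, ignoring masked (padding) positions.

    hidden:         array of shape (batch, seq_len, dim)
    attention_mask: array of shape (batch, seq_len), 1 for real tokens, 0 for padding
    skip_first:     drop position 0, mirroring the `[:, 1:]` slice used above
    """
    hidden = np.asarray(hidden, dtype=np.float32)
    mask = np.asarray(attention_mask, dtype=np.float32)
    if skip_first:
        hidden, mask = hidden[:, 1:], mask[:, 1:]
    # Zero out padding positions, then divide by the number of real tokens
    summed = (hidden * mask[..., None]).sum(axis=1)
    counts = np.clip(mask.sum(axis=1, keepdims=True), 1.0, None)
    return summed / counts


# Toy data: 2 sequences, 4 positions, 3 hidden dimensions
toy_hidden = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
toy_mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]])
pooled = masked_mean_pool(toy_hidden, toy_mask)
print(pooled.shape)  # (2, 3)
```

With the model outputs, the inputs would be `outputs["hidden_states"][-1].cpu().numpy()` and `inputs["attention_mask"].cpu().numpy()`.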