metadata
license: apache-2.0
metrics:
  - accuracy
pipeline_tag: text2text-generation
tags:
  - legal
  - text-generation-inference
datasets:
  - lgaalves/camel-ai-physics

Model Description

This model is trained on a large physics dataset covering “Special Relativity,” “Dark Matter,” “Black Holes,” “Quantum Mechanics,” “Plasma Physics,” “Particle Physics,” “Specific Theory,” “Nuclear Physics,” “Atomic Physics,” “Quantum Field Theory,” “Gravitational Waves,” “Electromagnetism,” and “Chaotic Theory.” It can answer physics questions under these headings.

from transformers import AutoTokenizer, GenerationConfig, T5ForConditionalGeneration
import torch
from peft import PeftModel
import time
import sys

repo_id = "sabankara/einstein"
model_name = 'google/flan-t5-small'
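tokenizer = AutoTokenizer.from_pretrained(repo_id)

# local_files_only=True assumes the base model is already in the local cache;
# device_map={"": 0} places the whole model on the first GPU.
peft_model_base = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, local_files_only=True, device_map={"": 0})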
peft_model = PeftModel.from_pretrained(peft_model_base, 
                                    repo_id, 
                                    torch_dtype=torch.bfloat16,
                                    is_trainable=False)


def print_character_by_character(text, delay=0.005, color_code="\033[36m"):
    for char in text:
        sys.stdout.write(f"{color_code}{char}\033[0m")
        sys.stdout.flush()
        time.sleep(delay)

def add_newline_after_punctuation(text):
    # Despite its name, this inserts a line break after every n words, not after punctuation.
    n = 17
    words = text.split()
    updated_text = ""

    for i, word in enumerate(words):
        updated_text += word + " "
        if (i + 1) % n == 0:
            updated_text += "\n"

    return updated_text.strip()

# Example
prompt = "What is a black hole?"

# Tokenize the prompt and move the input IDs to the first GPU.
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda(0)
generation_config = GenerationConfig(
    min_new_tokens=100,
    max_new_tokens=512,
    num_beams=1,
    do_sample=True,
    top_p=0.6,
    top_k=0,
    temperature=0.4,
    repetition_penalty=2.5,
)
model_outputs = peft_model.generate(input_ids=input_ids, generation_config=generation_config)

model_text_output = tokenizer.decode(model_outputs[0], skip_special_tokens=True)

albert_response = add_newline_after_punctuation(model_text_output)

print("\033[32mEinstein:\033[0m")
print_character_by_character(albert_response)
sys.stdout.write("\n")
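
The `add_newline_after_punctuation` helper above breaks the text after every 17 words rather than at punctuation. As a rough illustration of that behaviour, here is a minimal standalone sketch that takes the group size as a parameter. The name `wrap_every_n_words` and its `n` argument are hypothetical and not part of the model repository. It needs neither the model nor a GPU, so it can be run on its own.

```python
def wrap_every_n_words(text: str, n: int = 17) -> str:
    """Insert a line break after every n whitespace-separated words.

    Hypothetical helper for illustration; not part of the model repository.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    words = text.split()
    # Group the words into chunks of n and join each chunk into one line.
    lines = [" ".join(words[i:i + n]) for i in range(0, len(words), n)]
    return "\n".join(lines)


sample = "A black hole is a region of spacetime where gravity is so strong that nothing can escape"
print(wrap_every_n_words(sample, n=5))
# A black hole is a
# region of spacetime where gravity
# is so strong that nothing
# can escape
```

Unlike the original loop, this version leaves no trailing space before each line break. When wrapping by character width is preferable to wrapping by word count, the standard library's `textwrap.fill` is an alternative.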