Dataset: rahulholla1/stock-analysis
(Hugging Face Hub listing: Viewer • Updated • 948 • 5 • 4)
How to use vdpappu/lora_stock_analysis with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the Gemma-2b base model, then attach the LoRA adapter on top of it.
# NOTE(review): the full example in this card passes `token=` (read from the
# HUGGINGFACE_TOKEN environment variable) when loading "google/gemma-2b";
# you may need the same here.
base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
model = PeftModel.from_pretrained(base_model, "vdpappu/lora_stock_analysis")

# Model description: A Gemma-2b finetuned LoRA trained on science Q&A
# NOTE(review): the adapter name and the example prompt both concern stock /
# option analysis, not science Q&A — confirm which description is correct.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from peft import PeftModel
from typing import Optional
import time
import os
def generate_prompt(input_text: str, instruction: Optional[str] = None) -> str:
    """Build the Question/Answer prompt string sent to the model.

    Args:
        input_text: The question text, inserted after ``### Question:``.
        instruction: Optional instruction. When it is truthy (non-empty), it is
            prepended as an ``### Instruction:`` section before the question;
            ``None`` or ``""`` omits that section.

    Returns:
        The formatted prompt. It always ends with ``"### Answer: "`` so that
        the model's continuation is the answer.
    """
    prompt = f"### Question: {input_text}\n\n### Answer: "
    if instruction:
        prompt = f"### Instruction: {instruction}\n\n{prompt}"
    return prompt
# Hub access token from the environment (None when HUGGINGFACE_TOKEN is unset).
huggingface_token = os.environ.get('HUGGINGFACE_TOKEN')

# Repository ids for the base model and the LoRA adapter.
_BASE_REPO = "google/gemma-2b"
_ADAPTER_REPO = "vdpappu/lora_stock_analysis"
_hub_kwargs = {"token": huggingface_token}

base_model = AutoModelForCausalLM.from_pretrained(_BASE_REPO, **_hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(_BASE_REPO, **_hub_kwargs)

# Wrap the base model with the adapter, then use PEFT's merge_and_unload() to
# merge the adapter into the base weights. Generation below uses merged_model.
lora_model = PeftModel.from_pretrained(base_model, _ADAPTER_REPO)
merged_model = lora_model.merge_and_unload()

# Id of the last token produced when encoding the literal "<eos>" marker.
# NOTE(review): not referenced elsewhere in this script; the generation config
# uses tokenizer.eos_token_id instead.
eos_token = '<eos>'
eos_token_id = tokenizer.encode(eos_token, add_special_tokens=False)[-1]
# Sampling settings for the demo generation.
# Token budgets use the *_new_tokens variants: max_length/min_length also count
# the prompt tokens, so a long prompt could use up (or exceed) the 200-token
# limit and leave little or no room for the answer.
generation_config = GenerationConfig(
    eos_token_id=tokenizer.eos_token_id,
    min_new_tokens=5,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.5,   # >1.0 penalizes tokens that have already appeared
    no_repeat_ngram_size=3,   # forbid repeating any 3-gram verbatim
    # early_stopping is intentionally omitted: it only affects beam search,
    # and this config samples with the default num_beams=1.
)
# Example query. NOTE(review): the trailing ".." looks like it elides the rest
# of the original prompt — confirm the full text before relying on outputs.
question = """Assume the role as a seasoned stock option analyst with a strong track record in dissecting intricate option data to discern valuable
insights into stock sentiment. Proficient in utilizing advanced statistical models and data visualization techniques to forecast
market trends and make informed trading decisions. Adept at interpreting option Greeks, implied volatility, .. """
prompt = generate_prompt(input_text=question)
with torch.no_grad():
    # Keep the encoded prompt on the model's device. This is a no-op when the
    # model is on CPU and required if it is moved to an accelerator.
    inputs = tokenizer(prompt, return_tensors="pt").to(merged_model.device)
    output = merged_model.generate(**inputs, generation_config=generation_config)
# For a causal LM, generate() returns the prompt tokens followed by the
# continuation. Decode only the continuation so `response` is the answer alone.
prompt_len = inputs["input_ids"].shape[-1]
response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
print(response)
Base model: google/gemma-2b