import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# PRIMARY_LLM / FALLBACK_LLM are model identifiers defined in the local config module.
from config import PRIMARY_LLM, FALLBACK_LLM


def load_model():
    """Load the primary model in 8-bit; fall back to the secondary model in fp16."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(PRIMARY_LLM)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 8-bit quantization (requires the bitsandbytes package). Passing a
        # BitsAndBytesConfig via quantization_config is the current API; the
        # bare load_in_8bit kwarg is deprecated.
        model = AutoModelForCausalLM.from_pretrained(
            PRIMARY_LLM,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        print(f"Loaded primary model: {PRIMARY_LLM}")
    except Exception as e:
        print(f"Primary model failed: {e}")
        print(f"Loading fallback: {FALLBACK_LLM}")

        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_LLM, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            FALLBACK_LLM,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    model.eval()  # inference only; disable dropout etc.
    return tokenizer, model


# Load once at import time so generate() reuses the same instances.
tokenizer, model = load_model()


def generate(prompt, max_tokens=400):
    """Generate a completion for `prompt`, returning only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the tokens produced after the prompt. Slicing by input
    # length is more robust than string-replacing the prompt, which breaks
    # when the tokenizer normalizes whitespace or the prompt was truncated.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
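

# Minimal smoke test: a sketch of how generate() might be invoked directly.
# The prompt string below is illustrative only, not part of the original module.
if __name__ == "__main__":
    sample_prompt = "Explain what a tokenizer does in one sentence."
    print(generate(sample_prompt, max_tokens=100))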