Annessha18 committed on
Commit
d00a76d
·
verified ·
1 Parent(s): f236646

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +90 -0
model.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, Callable
3
+
4
+ from smolagents import LiteLLMModel
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
+ from functools import lru_cache
7
+ import time
8
+ import re
9
+ from litellm import RateLimitError
10
+
11
+
12
class LocalTransformersModel:
    """Thin wrapper exposing a Hugging Face text-generation pipeline as a callable."""

    def __init__(self, model_id: str, **kwargs):
        """Load tokenizer and causal-LM weights for *model_id*, then build a pipeline.

        Args:
            model_id (str): Hugging Face model identifier.
            **kwargs: Extra keyword arguments forwarded to
                AutoModelForCausalLM.from_pretrained.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer
        )

    def __call__(self, prompt: str, **kwargs):
        """Run generation on *prompt* and return the first candidate's text."""
        generations = self.pipeline(prompt, **kwargs)
        first_candidate = generations[0]
        return first_candidate["generated_text"]
21
+
22
class WrapperLiteLLMModel(LiteLLMModel):
    """LiteLLMModel that retries on rate limiting, honoring server-suggested delays."""

    def __call__(self, messages, **kwargs):
        """Forward *messages* to LiteLLM, retrying up to 5 times on RateLimitError.

        The retry delay is parsed from the provider's error payload when present
        (e.g. '"retryDelay": "42s"'); otherwise a 50-second fallback is used.

        Args:
            messages: Conversation messages forwarded to the underlying model.
            **kwargs: Extra keyword arguments forwarded to the underlying model.

        Returns:
            Whatever LiteLLMModel.__call__ returns.

        Raises:
            RateLimitError: The last provider error, re-raised once all
                attempts are exhausted.
        """
        max_retry = 5
        for attempt in range(max_retry):
            try:
                return super().__call__(messages, **kwargs)
            except RateLimitError as e:
                print(f"RateLimitError (attempt {attempt+1}/{max_retry})")

                # On the final attempt, don't sleep pointlessly — re-raise the
                # original exception. (The previous code built
                # RateLimitError(msg) after the loop, but litellm's
                # RateLimitError requires message/model/llm_provider, so that
                # construction itself raised TypeError and masked the error.)
                if attempt == max_retry - 1:
                    raise

                # Try to extract retry time from the exception string
                match = re.search(r'"retryDelay": ?"(\d+)s"', str(e))
                retry_seconds = int(match.group(1)) if match else 50

                print(f"Sleeping for {retry_seconds} seconds before retrying...")
                time.sleep(retry_seconds)
39
+
40
@lru_cache(maxsize=1)
def get_lite_llm_model(model_id: str, **kwargs) -> WrapperLiteLLMModel:
    """
    Build (and memoize) a rate-limit-aware LiteLLM model.

    The API key is read from the GEMINI_API environment variable.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        WrapperLiteLLMModel: LiteLLM model instance (cached, maxsize=1).
    """
    api_key = os.getenv("GEMINI_API")
    return WrapperLiteLLMModel(model_id=model_id, api_key=api_key, **kwargs)
53
+
54
+
55
@lru_cache(maxsize=1)
def get_local_model(model_id: str, **kwargs) -> LocalTransformersModel:
    """
    Returns a cached local Transformers model.

    Args:
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        LocalTransformersModel: local Transformers model instance.
        (Fixed: the docstring previously said "LiteLLMModel", a copy-paste
        error from get_lite_llm_model.)
    """
    return LocalTransformersModel(model_id=model_id, **kwargs)
68
+
69
+
70
def get_model(model_type: str, model_id: str, **kwargs) -> Any:
    """
    Dispatch to the factory matching *model_type* and return the model.

    Args:
        model_type (str): The type of the model (e.g., 'HfApiModel').
        model_id (str): The model identifier.
        **kwargs: Additional keyword arguments for the model.

    Returns:
        Any: Model instance of the specified type.

    Raises:
        ValueError: If *model_type* is not a known factory key.
    """
    factories: dict[str, Callable[..., Any]] = {
        "LiteLLMModel": get_lite_llm_model,
        "LocalTransformersModel": get_local_model,
    }

    factory = factories.get(model_type)
    if factory is None:
        raise ValueError(f"Unknown model type: {model_type}")

    return factory(model_id, **kwargs)