Chamin09 committed on
Commit
92b64f3
·
verified ·
1 Parent(s): 17f9952

Update models/llm_setup.py

Browse files
Files changed (1) hide show
  1. models/llm_setup.py +55 -61
models/llm_setup.py CHANGED
@@ -1,68 +1,62 @@
1
  import torch
2
- #from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
3
- from llama_index.llms.huggingface import HuggingFaceLLM
4
- #from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
5
- #from llama_index.llms.huggingface import HuggingFaceLLM
6
 
7
- def setup_llm(model_name: str = "microsoft/phi-3-mini-4k-instruct",
8
- context_window: int = 4096,
9
- max_new_tokens: int = 512):
10
- """Set up the language model for the CSV chatbot."""
11
 
12
- try:
13
- # Initialize LLM with the correct import
14
- llm = HuggingFaceInferenceAPI(
15
- model_name=model_name,
16
- tokenizer_name=model_name,
17
- context_window=context_window,
18
- max_new_tokens=max_new_tokens,
19
- generate_kwargs={"temperature": 0.7, "top_p": 0.95}
20
- )
21
- return llm
22
-
23
- except Exception as e:
24
- print(f"Error initializing HuggingFaceInferenceAPI: {e}")
25
-
26
- # Fallback to a simpler approach if needed
27
- from transformers import pipeline
28
-
29
  try:
30
- # Use a smaller model as fallback
31
- pipe = pipeline(
32
  "text-generation",
33
- model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
34
- torch_dtype="auto",
35
- device_map="auto",
 
36
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Create a simple wrapper to match LlamaIndex's expected interface
39
- class SimpleLLM:
40
- def complete(self, prompt):
41
- class Response:
42
- def __init__(self, text):
43
- self.text = text
44
-
45
- result = pipe(
46
- prompt,
47
- max_new_tokens=max_new_tokens,
48
- temperature=0.7,
49
- do_sample=True
50
- )
51
- generated_text = result[0]["generated_text"][len(prompt):]
52
- return Response(generated_text)
53
-
54
- return SimpleLLM()
55
-
56
- except Exception as e2:
57
- print(f"Fallback initialization also failed: {e2}")
58
-
59
- # Last resort - dummy LLM
60
- class DummyLLM:
61
- def complete(self, prompt):
62
- class Response:
63
- def __init__(self, text):
64
- self.text = text
65
-
66
- return Response("Model initialization failed. Please check logs.")
67
-
68
- return DummyLLM()
 
1
  import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 
 
3
 
4
class SimpleTransformersLLM:
    """A simple wrapper for Hugging Face Transformers models.

    Exposes a LlamaIndex-style ``complete(prompt)`` method that returns an
    object with a ``.text`` attribute, so it can stand in for a real LLM
    backend without requiring any API keys.
    """

    def __init__(self, model_name="google/flan-t5-small"):
        """Initialize with a small model that works on CPU.

        Args:
            model_name: Hugging Face model id. The default, flan-t5-small,
                is a T5 encoder-decoder model.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Bug fix: flan-t5 is a seq2seq (encoder-decoder) model. The
            # original "text-generation" task only supports causal LMs and
            # raised at pipeline construction, so self.pipe was always None
            # and the wrapper could never generate anything.
            self.pipe = pipeline(
                "text2text-generation",
                model=model_name,
                tokenizer=self.tokenizer,
                max_length=512,
                device_map="auto",
            )
        except Exception as e:
            # Degrade gracefully: complete() will report the failure
            # instead of raising at call time.
            print(f"Error initializing model: {e}")
            self.pipe = None

    def complete(self, prompt):
        """Complete a prompt with the model.

        Returns an object with a ``.text`` attribute containing either the
        generated text or a human-readable error message; never raises.
        """
        class Response:
            def __init__(self, text):
                self.text = text

        if self.pipe is None:
            return Response("Model initialization failed.")

        try:
            # NOTE: len(prompt) counts characters, not tokens — it's a rough
            # upper bound for the output length, kept from the original.
            result = self.pipe(prompt, max_length=len(prompt) + 200, do_sample=True)
            # text2text-generation returns only the newly generated text, so
            # there is no prompt prefix to strip (the original sliced off
            # len(prompt) characters, which would have discarded most of a
            # seq2seq response).
            response_text = result[0]["generated_text"].strip()
            if not response_text:
                response_text = "I couldn't generate a proper response."

            return Response(response_text)
        except Exception as e:
            print(f"Error generating response: {e}")
            return Response(f"Error generating response: {str(e)}")
45
def setup_llm():
    """Set up a simple LLM that doesn't require API keys.

    Attempts to build a small local Transformers model first; if that fails
    for any reason, falls back to a dummy object with the same interface.
    """

    class DummyLLM:
        """Last-resort stand-in used when no real model can be loaded."""

        def complete(self, prompt):
            class Response:
                def __init__(self, text):
                    self.text = text

            return Response("This is a dummy response. The actual model couldn't be loaded.")

    try:
        # Try with a very small model first
        return SimpleTransformersLLM("google/flan-t5-small")
    except Exception as e:
        print(f"Error setting up LLM: {e}")
        # Fallback to dummy LLM
        return DummyLLM()