fabioantonini committed on
Commit
91967b1
·
verified ·
1 Parent(s): 53c4d05

Upload llm_utils.py

Browse files
Files changed (1) hide show
  1. llm_utils.py +101 -0
llm_utils.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_utils.py
2
+
3
+ import os
4
+ import requests
5
+
6
+ # 1) HUGGING FACE INFERENCE API APPROACH
7
+ # ---------------------------------------
8
+ # This approach sends your prompt to the hosted Inference API on Hugging Face.
9
+ # You need:
10
+ # - A model endpoint, e.g. 'tiiuae/falcon-7b-instruct'
11
+ # - A valid HUGGINGFACEHUB_API_TOKEN with access to that model.
12
+ #
13
+ # Pros: no heavy model to download locally
14
+ # Cons: subject to model availability, rate limits, and does not run fully offline
15
+
16
+ HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "YOUR_API_TOKEN_HERE")
17
+ MODEL_ID = "tiiuae/falcon-7b-instruct"
18
+ API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
19
+
20
+ def get_llm_opinion_inference_api(prompt):
21
+ """
22
+ Queries the Hugging Face Inference API for a text generation model.
23
+ Returns the generated text as a string.
24
+ """
25
+ headers = {
26
+ "Authorization": f"Bearer {HF_API_TOKEN}",
27
+ "Content-Type": "application/json"
28
+ }
29
+ payload = {
30
+ "inputs": prompt,
31
+ "parameters": {
32
+ "max_new_tokens": 200,
33
+ "temperature": 0.5,
34
+ "top_p": 0.9,
35
+ "do_sample": True
36
+ }
37
+ }
38
+
39
+ response = requests.post(API_URL, headers=headers, json=payload)
40
+ if response.status_code != 200:
41
+ return f"Error: Hugging Face Inference API returned status {response.status_code}\n{response.text}"
42
+
43
+ # The Inference API returns an array of generated text(s)
44
+ data = response.json()
45
+ if isinstance(data, dict) and "error" in data:
46
+ return f"Error: {data['error']}"
47
+
48
+ # Typically, data[0]["generated_text"] holds the string
49
+ return data[0]["generated_text"]
50
+
51
+ # 2) LOCAL PIPELINE APPROACH
52
+ # --------------------------
53
+ # This approach loads a model locally via the Transformers library.
54
+ # This can be done on a Hugging Face Space if:
55
+ # - The model size fits the hardware resources (RAM/GPU)
56
+ # - The Space is configured to install transformers, etc.
57
+ # Pros: no external calls, faster for repeated queries
58
+ # Cons: potentially large downloads, memory usage
59
+ #
60
+ # If you want to use this approach, uncomment and adapt as needed:
61
+
62
+
63
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
64
+
65
+ def create_local_pipeline(model_id="tiiuae/falcon-7b-instruct"):
66
+ # Download and load the model locally
67
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
68
+ model = AutoModelForCausalLM.from_pretrained(
69
+ model_id,
70
+ device_map="auto" # or "cpu" if no GPU is available
71
+ )
72
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
73
+
74
+ # Example pipeline initialization (done once):
75
+ # generator = create_local_pipeline()
76
+
77
+ def get_llm_opinion_local(prompt, generator):
78
+ # Generate text from the local pipeline
79
+ outputs = generator(prompt, max_length=256, do_sample=True, temperature=0.5)
80
+ return outputs[0]["generated_text"]
81
+
82
+
83
+ # 3) WRAPPING LOGIC
84
+ # -----------------
85
+ # You can unify the approaches in a single function. For instance, if you want
86
+ # to prefer local inference if a pipeline is initialized, otherwise fallback to
87
+ # the Inference API:
88
+
89
+ def get_llm_opinion(prompt, generator=None):
90
+ """
91
+ High-level function to get the LLM's opinion.
92
+ If a local pipeline 'generator' is provided, use that.
93
+ Otherwise, fallback to the Hugging Face Inference API.
94
+ """
95
+ if generator is not None:
96
+ # local pipeline approach
97
+ outputs = generator(prompt, max_length=256, do_sample=True, temperature=0.5)
98
+ return outputs[0]["generated_text"]
99
+ else:
100
+ # inference API approach
101
+ return get_llm_opinion_inference_api(prompt)