File size: 3,699 Bytes
91967b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# llm_utils.py

import os
import requests

# 1) HUGGING FACE INFERENCE API APPROACH
# ---------------------------------------
# This approach sends your prompt to the hosted Inference API on Hugging Face.
# You need:
#   - A model endpoint, e.g. 'tiiuae/falcon-7b-instruct'
#   - A valid HUGGINGFACEHUB_API_TOKEN with access to that model.
#
# Pros: no heavy model to download locally
# Cons: subject to model availability and rate limits; requires network access (cannot run offline)

# Model served by the hosted Inference API, and the endpoint URL derived from it.
MODEL_ID = "tiiuae/falcon-7b-instruct"
API_URL = "https://api-inference.huggingface.co/models/" + MODEL_ID
# API token read from the environment. The fallback value is only a
# placeholder string, not a working credential.
# NOTE(review): set HUGGINGFACEHUB_API_TOKEN before calling the API.
HF_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN", "YOUR_API_TOKEN_HERE")

def get_llm_opinion_inference_api(prompt, *, timeout=120):
    """
    Query the Hugging Face Inference API for a text generation model.

    Sends ``prompt`` to ``API_URL``, authenticated with ``HF_API_TOKEN``.
    Sampling is enabled with the generation parameters set in the payload.

    Args:
        prompt: Input text sent to the model.
        timeout: Seconds passed to ``requests.post`` as its timeout. Without
            one, a stalled connection would block the caller indefinitely.

    Returns:
        The generated text as a string on success. On failure a string
        starting with ``"Error:"`` is returned instead of raising. Failures
        include:

        * a non-200 status;
        * a non-JSON body;
        * an ``"error"`` field in the payload;
        * a payload without the expected ``[{"generated_text": ...}]`` shape.

        Network-level errors raised by ``requests.post`` (for example
        connection failures or the timeout expiring) are not caught and
        propagate to the caller.
    """
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}",
        "Content-Type": "application/json"
    }
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 200,
            "temperature": 0.5,
            "top_p": 0.9,
            "do_sample": True
        }
    }

    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        return f"Error: Hugging Face Inference API returned status {response.status_code}\n{response.text}"

    # A 200 response is not guaranteed to carry valid JSON; report it the same
    # way as other failures instead of letting the decode error escape.
    try:
        data = response.json()
    except ValueError:
        return f"Error: Hugging Face Inference API returned a non-JSON body\n{response.text}"

    if isinstance(data, dict) and "error" in data:
        return f"Error: {data['error']}"

    # On success the API returns a list of generation dicts; use the first.
    # Guard against an empty list or an unexpected shape rather than crashing.
    try:
        return data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        return f"Error: unexpected response format from Hugging Face Inference API: {data!r}"

# 2) LOCAL PIPELINE APPROACH
# --------------------------
# This approach loads a model locally via the Transformers library.
# This can be done on a Hugging Face Space if:
#   - The model size fits the hardware resources (RAM/GPU)
#   - The Space is configured to install transformers, etc.
# Pros: no external calls, faster for repeated queries
# Cons: potentially large downloads, memory usage
#
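# To use this approach, build the pipeline once with create_local_pipeline()
# and pass it to get_llm_opinion_local() or get_llm_opinion().
# Note that the transformers import below is not commented out. It runs at
# module import time, so transformers must be installed even if you only use
# the Inference API.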


from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

def create_local_pipeline(model_id="tiiuae/falcon-7b-instruct", device_map="auto"):
    """
    Download and load a causal language model as a text-generation pipeline.

    Build the pipeline once and reuse it. Loading the model is the expensive
    part (large download, high memory use).

    Args:
        model_id: Hugging Face Hub identifier of the model to load.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            Defaults to ``"auto"``; pass ``"cpu"`` when no GPU is available.

    Returns:
        A ``transformers`` ``"text-generation"`` pipeline wrapping the loaded
        model and its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)

# Example pipeline initialization (done once):
# generator = create_local_pipeline()

def get_llm_opinion_local(prompt, generator, max_new_tokens=200, temperature=0.5):
    """
    Generate a response to ``prompt`` with a local text-generation pipeline.

    Args:
        prompt: Input text for the model.
        generator: A callable text-generation pipeline, e.g. the object
            returned by ``create_local_pipeline()``.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. The default of 200 matches the Inference API path.
        temperature: Sampling temperature. Sampling is always enabled.

    Returns:
        The ``"generated_text"`` field of the pipeline's first output.
    """
    # Budget only the newly generated tokens. ``max_length`` would count the
    # prompt too, so a long prompt could leave little or no room for an answer.
    outputs = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    return outputs[0]["generated_text"]


# 3) WRAPPING LOGIC
# -----------------
# You can unify the approaches in a single function. For instance, you may
# want to prefer local inference when a pipeline has been initialized and
# otherwise fall back to the Inference API:

def get_llm_opinion(prompt, generator=None):
    """
    High-level function to get the LLM's opinion on ``prompt``.

    If a local pipeline ``generator`` is provided, it is used. Otherwise the
    call falls back to the Hugging Face Inference API through
    ``get_llm_opinion_inference_api``.

    Args:
        prompt: Input text for the model.
        generator: Optional callable text-generation pipeline, e.g. from
            ``create_local_pipeline()``. ``None`` selects the Inference API.

    Returns:
        The generated text. On the Inference API path this may instead be a
        string starting with ``"Error:"``.
    """
    if generator is None:
        # No local model loaded: use the hosted Inference API.
        return get_llm_opinion_inference_api(prompt)

    # Local pipeline path. Use max_new_tokens rather than max_length so the
    # prompt does not eat into the generation budget. 200 matches the API path.
    outputs = generator(prompt, max_new_tokens=200, do_sample=True, temperature=0.5)
    return outputs[0]["generated_text"]