VcRlAgent committed on
Commit
a1544bb
·
1 Parent(s): a94e830

Generator Refactor HF Inference Client API

Browse files
app/services/generator.py CHANGED
@@ -1,88 +1,77 @@
1
- """LLM generation service using Hugging Face Inference API"""
2
- import requests
3
- from typing import Dict, Any, Optional
 
4
  from app.config import settings
5
  from app.utils.logger import setup_logger
6
 
7
  logger = setup_logger(__name__)
8
 
 
9
  class GeneratorService:
10
- """Handles text generation using Hugging Face models"""
11
-
12
  def __init__(self):
13
- self.api_url = settings.HF_API_URL
14
- self.headers = {"Authorization": f"Bearer {settings.HF_TOKEN}"}
15
-
 
 
 
16
  def generate(
17
  self,
18
  prompt: str,
19
  max_tokens: int = 512,
20
- temperature: float = 0.7
21
  ) -> str:
22
- """Generate text using the LLM"""
23
- payload = {
24
- "inputs": prompt,
25
- "parameters": {
26
- "max_new_tokens": max_tokens,
27
- "temperature": temperature,
28
- "return_full_text": False
29
- }
30
- }
31
-
32
  try:
33
- logger.info("Calling Hugging Face API...")
34
- response = requests.post(
35
- self.api_url,
36
- headers=self.headers,
37
- json=payload,
38
- timeout=30
 
39
  )
40
- response.raise_for_status()
41
-
42
- result = response.json()
43
-
44
- # Handle different response formats
45
- if isinstance(result, list) and len(result) > 0:
46
- generated_text = result[0].get('generated_text', '')
47
- elif isinstance(result, dict):
48
- generated_text = result.get('generated_text', '')
49
- else:
50
- generated_text = str(result)
51
-
52
  logger.info("Generation successful")
 
53
  return generated_text.strip()
54
-
55
- except requests.exceptions.RequestException as e:
56
- logger.error(f"API request failed: {str(e)}")
57
- # Fallback to simple response
58
  return self._fallback_response(prompt)
59
-
60
  def _fallback_response(self, prompt: str) -> str:
61
- """Fallback response when API fails"""
62
- return "I apologize, but I'm unable to generate a response at the moment. Please try again later."
63
-
64
- def generate_rag_response(
65
- self,
66
- query: str,
67
- context: str
68
- ) -> str:
69
- """Generate response using RAG pattern"""
70
  prompt = self._build_rag_prompt(query, context)
71
  return self.generate(prompt)
72
-
73
  def _build_rag_prompt(self, query: str, context: str) -> str:
74
- """Build RAG prompt template"""
75
- prompt = f"""<s>[INST] You are WorkWise, an AI assistant specialized in analyzing Jira project data. Answer the user's question based on the provided context.
 
 
76
 
77
  Context:
78
  {context}
79
 
80
  User Question: {query}
81
 
82
- Provide a clear, concise answer based on the context. If the context doesn't contain enough information, say so. [/INST]</s>
 
 
83
 
84
- Answer:"""
85
- return prompt
86
 
87
  # Global instance
88
- generator = GeneratorService()
 
1
+ """LLM generation service using Hugging Face Inference Client SDK"""
2
+ import os
3
+ from typing import Optional
4
+ from huggingface_hub import InferenceClient
5
  from app.config import settings
6
  from app.utils.logger import setup_logger
7
 
8
  logger = setup_logger(__name__)
9
 
10
+
11
class GeneratorService:
    """Handles text generation using the Hugging Face InferenceClient.

    Wraps a single reusable chat-completion client. All public entry points
    degrade to a canned fallback message instead of raising, so callers
    never see a transport/SDK exception.
    """

    def __init__(self):
        # One reusable inference client for all requests.
        self.client = InferenceClient(api_key=settings.HF_TOKEN)

        # Model id from settings when present, else a sensible default.
        self.model = getattr(settings, "HF_MODEL", "meta-llama/Llama-3.1-8B-Instruct")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
    ) -> str:
        """Generate text for *prompt* via the HF chat-completion API.

        Args:
            prompt: User text, sent as a single ``user`` chat message.
            max_tokens: Upper bound on tokens to generate.
            temperature: Sampling temperature.

        Returns:
            The stripped completion text (possibly empty), or the fallback
            message if the API call fails for any reason.
        """
        try:
            logger.info(f"Calling HF InferenceClient (model={self.model})...")

            completion = self.client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
            )

            # `message.content` may legitimately be None in the OpenAI-style
            # response schema; guard so an empty completion returns "" instead
            # of raising AttributeError on .strip() and being misreported
            # below as an API failure.
            generated_text = completion.choices[0].message.content or ""

            logger.info("Generation successful")

            return generated_text.strip()

        except Exception as e:
            # Deliberate best-effort boundary: log and degrade, never raise.
            logger.error(f"HF Generation failed: {str(e)}")
            return self._fallback_response(prompt)

    def _fallback_response(self, prompt: str) -> str:
        """Fallback response when LLM API fails"""
        return (
            "I apologize, but I'm unable to generate a response at the moment. "
            "Please try again later."
        )

    def generate_rag_response(self, query: str, context: str) -> str:
        """Generate response using RAG-style prompt formatting"""
        prompt = self._build_rag_prompt(query, context)
        return self.generate(prompt)

    def _build_rag_prompt(self, query: str, context: str) -> str:
        """Build WorkWise-style RAG prompt"""
        return f"""
You are WorkWise, an AI assistant specialized in analyzing Jira project data.
Answer the user's question based only on the context.

Context:
{context}

User Question: {query}

Provide a clear, concise answer.
If the context doesn't contain enough information, say so.
""".strip()
74
 
 
 
75
 
76
# Global instance
# Module-level singleton shared by the app; the InferenceClient is
# constructed at import time (requires settings.HF_TOKEN to be set).
generator = GeneratorService()
app/services/generator.py.legacyJSON ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM generation service using Hugging Face Inference API"""
2
+ import requests
3
+ from typing import Dict, Any, Optional
4
+ from app.config import settings
5
+ from app.utils.logger import setup_logger
6
+
7
+ logger = setup_logger(__name__)
8
+
9
class GeneratorService:
    """Handles text generation using Hugging Face models"""
    # Legacy snapshot of the pre-refactor implementation that called the
    # raw HF Inference HTTP API with `requests`. Kept for reference only;
    # the live service now uses huggingface_hub.InferenceClient.

    def __init__(self):
        # Raw endpoint URL plus bearer-token header for the HF Inference API.
        self.api_url = settings.HF_API_URL
        self.headers = {"Authorization": f"Bearer {settings.HF_TOKEN}"}

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7
    ) -> str:
        """Generate text using the LLM"""
        # HF text-generation payload; return_full_text=False strips the
        # echoed prompt from the model output.
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "return_full_text": False
            }
        }

        try:
            logger.info("Calling Hugging Face API...")
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()

            result = response.json()

            # Handle different response formats
            if isinstance(result, list) and len(result) > 0:
                generated_text = result[0].get('generated_text', '')
            elif isinstance(result, dict):
                generated_text = result.get('generated_text', '')
            else:
                generated_text = str(result)

            logger.info("Generation successful")
            return generated_text.strip()

        # NOTE(review): only RequestException is caught — a non-JSON body
        # (response.json() raising ValueError) would propagate to the caller
        # instead of hitting the fallback. Left as-is in this legacy copy.
        except requests.exceptions.RequestException as e:
            logger.error(f"API request failed: {str(e)}")
            # Fallback to simple response
            return self._fallback_response(prompt)

    def _fallback_response(self, prompt: str) -> str:
        """Fallback response when API fails"""
        return "I apologize, but I'm unable to generate a response at the moment. Please try again later."

    def generate_rag_response(
        self,
        query: str,
        context: str
    ) -> str:
        """Generate response using RAG pattern"""
        prompt = self._build_rag_prompt(query, context)
        return self.generate(prompt)

    def _build_rag_prompt(self, query: str, context: str) -> str:
        """Build RAG prompt template"""
        # Mistral/Llama-2-style [INST] instruction wrapper around the
        # context and question.
        prompt = f"""<s>[INST] You are WorkWise, an AI assistant specialized in analyzing Jira project data. Answer the user's question based on the provided context.

Context:
{context}

User Question: {query}

Provide a clear, concise answer based on the context. If the context doesn't contain enough information, say so. [/INST]</s>

Answer:"""
        return prompt
+
87
# Global instance
# Legacy snapshot: module-level singleton constructed at import time
# (reads settings.HF_API_URL / settings.HF_TOKEN).
generator = GeneratorService()
requirements.txt CHANGED
@@ -1,8 +1,10 @@
 
1
  fastapi==0.109.0
2
  uvicorn[standard]==0.27.0
3
  python-dotenv==1.0.0
4
  python-multipart==0.0.6 # if you accept file uploads
5
 
 
6
  # === Data / utilities ===
7
  pandas==2.2.0
8
  numpy==1.26.3
 
1
+ huggingface-hub>=0.26.0
2
  fastapi==0.109.0
3
  uvicorn[standard]==0.27.0
4
  python-dotenv==1.0.0
5
  python-multipart==0.0.6 # if you accept file uploads
6
 
7
+
8
  # === Data / utilities ===
9
  pandas==2.2.0
10
  numpy==1.26.3