Dheeraj-13 committed
Commit d5cf328 · 1 Parent(s): 267abd7

Fix broken generate.py and implement lazy loading

Files changed (1)
  1. services/rag/generate.py +51 -65
services/rag/generate.py CHANGED

@@ -2,6 +2,7 @@ import os
 from typing import List, Dict
 from openai import OpenAI
 from ..observability.langfuse_client import observe
+import torch

 SYSTEM_PROMPT = """You are a grounded knowledge assistant.
 Your goal is to answer the user's question using ONLY the provided context.
@@ -14,12 +15,15 @@ Rules:
 5. Be concise and direct.
 """

+# Global variable for lazy loading on the worker node
+_local_pipeline = None
+
 class GeneratorService:
     def __init__(self):
         # Initialize OpenAI if key exists
         self.openai_client = None
-        self.openai_model = "gpt-5-nano" # Default
-        # Initialize OpenAI if key exists
+        self.openai_model = "gpt-4o-mini" # User reported gpt-5 errors, safer default? Or keep logic.
+
         api_key = os.getenv("OPENAI_API_KEY")
         if api_key:
             self.openai_client = OpenAI(api_key=api_key)
@@ -38,90 +42,72 @@ class GeneratorService:
     def generate(self, query: str, context_chunks: List[Dict], backend: str = "openai") -> str:
         """
         backend: 'openai' or 'local'
+        NOTE: This method must be running in a @spaces.GPU context if backend='local'.
         """
         context = self._format_context(context_chunks)

-        if backend == "openai" and self.openai_client is None:
-            return "Error: OpenAI backend selected but OPENAI_API_KEY not found. Please switch to Local or set key."
-
-        if use_openai:
-            # OpenAI Logic
-            input_messages = [
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
-            ]
+        # Check explicit backend choice
+        if backend == "openai":
+            if self.openai_client is None:
+                return "Error: OpenAI backend selected but OPENAI_API_KEY not found. Please switch to Local or set key."

-            if "gpt-5-" in self.openai_model or "nano" in self.openai_model:
-                try:
-                    response = self.openai_client.responses.create(
-                        model=self.openai_model,
-                        input=input_messages,
-                        reasoning={"effort": "medium"},
-                        text={"verbosity": "medium"},
-                        max_output_tokens=1000
-                    )
-                    return response.output_text
-                except Exception as e:
-                    return f"OpenAI Error: {str(e)}"
-            else:
+            # OpenAI Generation
+            try:
+                # Basic Chat Completion
                 response = self.openai_client.chat.completions.create(
                     model=self.openai_model,
-                    messages=input_messages,
-                    temperature=0.1
+                    messages=[
+                        {"role": "system", "content": SYSTEM_PROMPT},
+                        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
+                    ]
                 )
                 return response.choices[0].message.content
+            except Exception as e:
+                return f"OpenAI Error: {str(e)}"
+
         else:
-            # Local Logic (Mistral / ZeroGPU)
-            self._ensure_local_loaded()
+            # Local Generation (Mistral)
+            # This block expects to be running on ZeroGPU (enforced by app.py decorator)
+
+            global _local_pipeline

-            # Mistral expects standard chat messages usually
+            # Lazy Load the model here (on the GPU node)
+            if _local_pipeline is None:
+                print("Loading local Mistral-7B model (Lazy Load)...")
+                try:
+                    from transformers import pipeline
+                    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
+                    _local_pipeline = pipeline(
+                        "text-generation",
+                        model=model_id,
+                        torch_dtype=torch.float16,
+                        device_map="auto"
+                    )
+                except Exception as e:
+                    return f"Failed to load local model: {e}"
+
+            # Prepare messages
             messages = [
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
             ]

-            # ZeroGPU / Spaces Handling
-            def run_pipeline(msgs):
-                # Modern pipelines handle list of messages automatically applying chat template
-                outputs = self.local_pipeline(
-                    msgs,
-                    max_new_tokens=512, # Increased for Mistral
+            try:
+                outputs = _local_pipeline(
+                    messages,
+                    max_new_tokens=512,
                     do_sample=True,
                     temperature=0.1,
                     top_k=50,
                     top_p=0.95
                 )
-                # Output is usually a list of dicts.
-                # For chat pipeline: [{'generated_text': [..., {'role': 'assistant', 'content': '...'}]}]
-                # OR just string if text-generation vs text-generation-chat?
-                # Default pipeline "text-generation" with list input usually returns:
-                # [{'generated_text': [{'role': 'user', ...}, {'role': 'assistant', 'content': 'Response'}]}]
-                return outputs[0]['generated_text']
-
-            # Try to decorate with spaces.GPU
-            try:
-                import spaces
-                print("ZeroGPU enabled for this generation.")
-                run_pipeline = spaces.GPU(run_pipeline)
-            except ImportError:
-                pass
+                # Parse output
+                result = outputs[0]['generated_text']
+                if isinstance(result, list):
+                    return result[-1]['content']
+                return str(result)
             except Exception as e:
-                print(f"Could not use ZeroGPU: {e}")
-
-            result = run_pipeline(messages)
-
-            # Parse result (Transformers pipeline behavior varies by version/call)
-            # If result is a string (rare for chat list input), return it.
-            # If result is a list of messages (standard for chat), extract last content.
-            if isinstance(result, list):
-                # Check if it's the full conversation
-                last_msg = result[-1]
-                if last_msg.get('role') == 'assistant':
-                    return last_msg['content']
-                else:
-                    # Fallback
-                    return str(result)
-            return str(result)
+                return f"Generation Error: {e}"

 _shared_generator = None
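
Note: the NOTE added to generate() assumes the caller, not this module, acquires the ZeroGPU context. A minimal sketch of what that wrapper might look like in app.py follows; the handler names, the duration value, and the get_generator() accessor are assumptions for illustration, not part of this commit.

# Hypothetical app.py wrappers (not part of this commit). The @spaces.GPU
# decorator asks HF Spaces to run the call on a ZeroGPU worker, so the lazy
# pipeline load inside generate() happens on the GPU node, not at import time.
import spaces
from services.rag.generate import get_generator  # assumed accessor around _shared_generator

@spaces.GPU(duration=120)  # seconds of GPU time requested; value is an assumption
def answer_local(query, context_chunks):
    return get_generator().generate(query, context_chunks, backend="local")

def answer_openai(query, context_chunks):
    # Hosted backend needs no GPU allocation.
    return get_generator().generate(query, context_chunks, backend="openai")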
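
Note: the new parsing code relies on how recent transformers text-generation pipelines behave when called with a list of chat messages: generated_text comes back as the whole conversation (a list of role/content dicts) rather than a plain string. A standalone sketch of the assumed shape, with the same fallback used in the commit:

# Sketch of the output shape the new parsing assumes (chat-style input to a
# "text-generation" pipeline); older or plain-string pipelines hit the fallback.
def extract_reply(outputs):
    result = outputs[0]["generated_text"]
    if isinstance(result, list):        # chat-style: full conversation as message dicts
        return result[-1]["content"]    # last message is the assistant turn
    return str(result)                  # fallback for plain-string outputs

example = [{"generated_text": [
    {"role": "user", "content": "What is RAG?"},
    {"role": "assistant", "content": "Retrieval-augmented generation."},
]}]
print(extract_reply(example))  # -> "Retrieval-augmented generation."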
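
Note: for reference, a hedged usage sketch of the unchanged public surface: generate() takes the retrieved chunks as a list of dicts plus a backend string. The chunk schema consumed by _format_context is not shown in this diff, so the "text"/"source" keys below are assumptions.

# Hypothetical caller; chunk keys are assumptions, not taken from this commit.
from services.rag.generate import GeneratorService

generator = GeneratorService()
chunks = [
    {"text": "Langfuse provides tracing for LLM apps.", "source": "docs/observability.md"},
    {"text": "ZeroGPU Spaces allocate GPUs per request.", "source": "docs/deploy.md"},
]

# Hosted backend: requires OPENAI_API_KEY in the environment.
print(generator.generate("What does Langfuse do?", chunks, backend="openai"))

# Local backend: should run inside a @spaces.GPU-decorated handler on Spaces.
print(generator.generate("What does Langfuse do?", chunks, backend="local"))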