Rithankoushik commited on
Commit
18be1bf
·
verified ·
1 Parent(s): a2a6945

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +62 -255
inference.py CHANGED
@@ -1,282 +1,89 @@
1
- import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
- import os
5
- import json
6
- import time
7
  import re
 
 
8
  import json5
9
  from transformers import AutoTokenizer, AutoModelForCausalLM
10
- import torch
11
- import streamlit as st
12
 
13
- @st.cache_resource(show_spinner="Loading model and tokenizer from Hugging Face Hub...")
14
- def load_model_and_tokenizer():
15
- MODEL_REPO = "Rithankoushik/job-parser-model-qwen" # your HF repo
16
 
17
- tokenizer = AutoTokenizer.from_pretrained(
18
- MODEL_REPO,
19
- trust_remote_code=True,
20
- )
21
- model = AutoModelForCausalLM.from_pretrained(
22
- MODEL_REPO,
 
 
23
  trust_remote_code=True,
24
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
25
  device_map="auto"
26
  )
 
 
 
27
 
28
 
 
 
 
 
 
29
 
30
- return tokenizer, model
31
-
32
-
33
- tokenizer, model = load_model_and_tokenizer()
34
-
35
- def extract_json_from_output(text):
36
- # Improved JSON extraction: find first '{' and match until the closing '}'
37
- start = text.find('{')
38
- if start == -1:
39
- return text
40
- stack = []
41
- for i in range(start, len(text)):
42
- if text[i] == '{':
43
- stack.append('{')
44
- elif text[i] == '}':
45
- stack.pop()
46
- if not stack:
47
- return text[start:i+1]
48
- # fallback if no matching closing brace found
49
- return text[start:]
50
-
51
- @st.cache_data
52
- def get_static_prompt_parts():
53
- system_prompt = (
54
- "You are a highly accurate JSON extractor for job descriptions. "
55
- "Your ONLY task is to extract what is explicitly mentioned in the job description. "
56
- "Do NOT guess or infer. If a field is not present in the job description, return an empty value for it. "
57
- "Always follow the provided JSON schema. Return ONLY the raw JSON object, with no additional text or formatting. "
58
- "Avoid hallucinations. Do not fabricate emails, phone numbers, websites, salaries, or skills that are not clearly mentioned. "
59
- )
60
-
61
- json_schema = """{
62
- "job_titles": [],
63
- "organization": { "employers": [], "websites": [] },
64
- "job_contact_details": { "email_address": [], "phone_number": [], "websites": [] },
65
- "location": { "hiring": [], "org_location": [] },
66
- "employment_details": { "employment_type": [], "work_mode": [] },
67
- "compensation": {
68
- "salary": [
69
- {
70
- "amount_in_text": "",
71
- "time_frequency": "",
72
- "parsed": { "min": "", "max": "", "currency": "" }
73
- }
74
- ],
75
- "benefits": []
76
- },
77
- "technical_skills": [ { "skill_name": "" } ],
78
- "soft_skills": [],
79
- "work_experience": {
80
- "min_in_years": null,
81
- "max_in_years": null,
82
- "role_experience": [
83
- { "min_in_years":null, "max_in_years":null, "skill": "" }
84
- ],
85
- "skill_experience": [
86
- { "min_in_years":null, "max_in_years":null, "skill": "" }
87
- ]
88
- },
89
- "qualifications": [
90
- { "qualification": [], "specilization": [] }
91
- ],
92
- "certifications": [],
93
- "languages": []
94
- }"""
95
- example_jd = """Job Title: Sustainability Analyst
96
- Company: HelioCore Energy GmbH
97
- Location:
98
-
99
- Hiring for: Berlin, Germany
100
-
101
- Org HQ: Berlin, Germany
102
-
103
- Employment Type: Full-time
104
- Work Mode: Hybrid (3 days onsite, 2 remote)
105
-
106
- Overview:
107
- HelioCore Energy GmbH is at the forefront of Europe's green transition, delivering scalable renewable energy projects across solar, wind, and hydrogen.
108
- As a Sustainability Analyst, you will work with our ESG, operations, and strategy teams to measure, improve, and report our sustainability performance while staying compliant with EU regulations.
109
-
110
- Key Responsibilities:
111
-
112
- Collect and analyze sustainability KPIs and ESG metrics from internal teams and partners.
113
-
114
- Create dashboards and reports aligned with CSRD and EU Taxonomy compliance.
115
-
116
- Collaborate with engineering teams to assess environmental impact of ongoing projects.
117
 
118
- Contribute to corporate sustainability strategy and annual disclosures.
119
-
120
- Benchmark company initiatives against global sustainability standards (GRI, SASB).
121
-
122
- Qualifications & Requirements:
123
-
124
- Bachelor's degree in Environmental Science, Sustainability, Economics, or related field.
125
-
126
- Up to 2 years of experience in sustainability reporting or ESG analytics.
127
-
128
- Proficiency in Excel, Power BI, or similar data tools is a plus.
129
-
130
- Familiarity with EU climate policy and frameworks.
131
-
132
- Certifications:
133
-
134
- GRI Certified Sustainability Professional (preferred)
135
-
136
- Languages:
137
-
138
- English (Fluent)
139
- German
140
- Compensation & Benefits:
141
- Salary: €3,000 - €3,600 per month
142
- Benefits: Green mobility stipend, learning budget, hybrid work flexibility, subsidized lunches, gym membership.
143
- Contact Information:
144
- Email: careers@heliocore.de"""
145
-
146
- example_json_output = """{
147
- "job_titles": ["Sustainability Analyst"],
148
- "organization": {
149
- "employers": ["HelioCore Energy GmbH"],
150
- "websites": []
151
- },
152
- "job_contact_details": {
153
- "email_address": ["careers@heliocore.de"],
154
- "phone_number": [],
155
- "websites": []
156
- },
157
- "location": {
158
- "hiring": ["Berlin, Germany"],
159
- "org_location": ["Berlin, Germany"]
160
- },
161
- "employment_details": {
162
- "employment_type": ["Full-time"],
163
- "work_mode": ["Hybrid"]
164
- },
165
- "compensation": {
166
- "salary": [
167
- {
168
- "amount_in_text": "€3,000 - €3,600 per month",
169
- "time_frequency": "monthly",
170
- "parsed": {
171
- "min": "3000",
172
- "max": "3600",
173
- "currency": "EUR"
174
- }
175
- }
176
- ],
177
- "benefits": [
178
- "Green mobility stipend",
179
- "Learning budget",
180
- "Hybrid work flexibility",
181
- "Subsidized lunches",
182
- "Gym membership"
183
- ]
184
- },
185
- "technical_skills": [
186
- {"skill_name": "Sustainability reporting"},
187
- {"skill_name": "ESG metrics"},
188
- {"skill_name": "Data visualization"},
189
- {"skill_name": "EU Taxonomy"},
190
- {"skill_name": "Environmental impact analysis"},
191
- {"skill_name": "Power BI"},
192
- {"skill_name": "Excel"},
193
- {"skill_name": "Carbon footprint modeling"}
194
- ],
195
- "soft_skills": [
196
- "Analytical thinking",
197
- "Communication",
198
- "Attention to detail",
199
- "Team collaboration",
200
- "Problem-solving"
201
- ],
202
- "work_experience": {
203
- "min_in_years": 0,
204
- "max_in_years": 2,
205
- "role_experience": [
206
- {
207
- "min_in_years": 0,
208
- "max_in_years": 2,
209
- "skill": "Sustainability analytics"
210
- }
211
- ],
212
- "skill_experience": [
213
- {
214
- "min_in_years": 0,
215
- "max_in_years": 2,
216
- "skill": "ESG frameworks"
217
- },
218
- {
219
- "min_in_years": 0,
220
- "max_in_years": 1,
221
- "skill": "Dashboarding"
222
- }
223
- ]
224
- },
225
- "qualifications": [
226
- {
227
- "qualification": ["Bachelor's Degree"],
228
- "specilization": ["Environmental Science", "Sustainability", "Economics"]
229
- }
230
- ],
231
- "certifications": ["GRI Certified Sustainability Professional"],
232
- "languages": ["English", "German"]
233
- }"""
234
-
235
-
236
- return system_prompt, json_schema, example_jd, example_json_output
237
 
238
 
239
  def infer_from_text(jd_text: str):
 
240
  start_time = time.time()
241
 
242
- system_prompt, json_schema, example_jd, example_json_output = get_static_prompt_parts()
243
-
244
- # Build user prompt only (changing part)
245
- user_prompt = f"""
246
-
247
- Now, perform the same task on the following new job description.
248
 
249
- New Job Description to be parsed:
250
- ---
251
- {jd_text}
252
- ---
253
 
254
- JSON Schema to follow:
255
- ---
256
- {json_schema}
257
- ---
258
- """
259
  messages = [
260
  {"role": "system", "content": system_prompt},
261
  {"role": "user", "content": user_prompt}
262
  ]
263
 
264
- prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
265
- inputs = tokenizer(prompt, return_tensors="pt")
266
- device = model.device if hasattr(model, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu")
267
- inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
268
 
269
- with torch.no_grad():
270
- outputs = model.generate(**inputs, max_new_tokens=800, do_sample=False)
271
- raw_response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
272
- cleaned = extract_json_from_output(raw_response).replace("None", "null").strip()
273
 
274
- try:
275
- parsed = json5.loads(cleaned)
276
- except Exception:
277
- try:
278
- parsed = json5.loads(cleaned)
279
- except Exception:
280
- return raw_response, round(time.time() - start_time, 2)
281
-
282
- return json.dumps(parsed, indent=2), round(time.time() - start_time, 2)
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
 
 
 
2
  import re
3
+ import time
4
+ import json
5
  import json5
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ from peft import PeftModel
 
8
 
9
+ # Model paths
10
+ base_model_id = "Qwen/Qwen3-0.6B"
11
+ lora_model_id = "Rithankoushik/Qwen-0.6-Job-parser-Model"
12
 
13
+ # Load tokenizer
14
+ tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
15
+ tokenizer.pad_token = tokenizer.eos_token
16
+ tokenizer.pad_token_id = tokenizer.eos_token_id # ✅ critical fix
17
+
18
+ # Load model + LoRA
19
+ base_model = AutoModelForCausalLM.from_pretrained(
20
+ base_model_id,
21
  trust_remote_code=True,
22
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
  device_map="auto"
24
  )
25
+ model = PeftModel.from_pretrained(base_model, lora_model_id, device_map="auto")
26
+ model = model.merge_and_unload()
27
+ model.eval()
28
 
29
 
30
+ def extract_and_clean_json(text):
31
+ """Extract JSON from LLM output, even if extra text is present."""
32
+ matches = re.findall(r"\{[\s\S]*\}", text)
33
+ if not matches:
34
+ return None
35
 
36
+ json_str = matches[0] # take first JSON
37
+ json_str = json_str.replace("None", "null")
38
+ json_str = json_str.replace("True", "true").replace("False", "false")
39
+ json_str = re.sub(r",(\s*[}\]])", r"\1", json_str)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
+ try:
42
+ return json5.loads(json_str)
43
+ except Exception as e:
44
+ print(f"JSON parse error: {e}")
45
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def infer_from_text(jd_text: str):
49
+ """Runs inference on a job description."""
50
  start_time = time.time()
51
 
52
+ system_prompt = "Extract structured information from the following job description and return it as JSON."
 
 
 
 
 
53
 
54
+ user_prompt = f"Job Description:\n{jd_text}"
 
 
 
55
 
 
 
 
 
 
56
  messages = [
57
  {"role": "system", "content": system_prompt},
58
  {"role": "user", "content": user_prompt}
59
  ]
60
 
61
+ # safer way
62
+ prompt = tokenizer.apply_chat_template(
63
+ messages,
64
+ tokenize=False,
65
+ add_generation_prompt=True
66
+ )
67
 
68
+ raw_inputs = tokenizer(prompt, return_tensors="pt")
69
+ device = model.device
70
+ inputs = {k: v.to(device) for k, v in raw_inputs.items()}
 
71
 
72
+ with torch.no_grad():
73
+ out = model.generate(
74
+ **inputs,
75
+ max_new_tokens=1000,
76
+ do_sample=False,
77
+ temperature=0,
78
+ pad_token_id=tokenizer.pad_token_id
79
+ )
80
+
81
+ gen_tokens = out[0][inputs["input_ids"].shape[1]:]
82
+ response_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
83
+ duration = round(time.time() - start_time, 2)
84
+
85
+ parsed = extract_and_clean_json(response_text)
86
+ if parsed is not None:
87
+ return json.dumps(parsed, indent=2), duration
88
+
89
+ return response_text, duration