Update app.py
app.py CHANGED

@@ -79,28 +79,36 @@ def generate_single_model_output(model, tokenizer, prompt, max_length=300, tempe
     return tokenizer.decode(output[0], skip_special_tokens=True).strip()
 
 # === ANALYSIS WITH FLAN-T5 ===
+def extract_json_only(text):
+    """Extract just the first JSON object from model text output."""
+    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
+    matches = re.findall(pattern, text, re.DOTALL)
+    return matches[0] if matches else ""
+
 def analyze_with_cpu_model(raw_outputs, zero_shot_injury):
-
+    # Only extract JSON from each model output
+    json_blobs = []
     for i, text in enumerate(raw_outputs):
-
+        json_part = extract_json_only(text)
+        if json_part:
+            json_blobs.append(f"Model {i+1} JSON:\n{json_part}")
+
+    summary = "\n\n".join(json_blobs)
 
     prompt = (
-        f"The following are
-        f"{summary}\n"
-        f"A separate
-        f"Please analyze all
-        f"
-        f"
-        f"Return only in the format:\n"
-        f"Cause of Accident: ...\n"
-        f"Degree of Injury: ..."
+        f"The following are JSON outputs from multiple hazard prediction models:\n\n"
+        f"{summary}\n\n"
+        f"A separate classifier predicted this injury severity: {zero_shot_injury}.\n\n"
+        f"Please analyze all JSON outputs and return:\n"
+        f"Cause of Accident: <natural language summary of the most likely cause>\n"
+        f"Degree of Injury: <Low | Medium | High>"
     )
 
     inputs = flan_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
     with torch.no_grad():
         output = flan_model.generate(
             **inputs,
-            max_length=
+            max_length=128,
             temperature=0.5,
             top_p=0.9,
             do_sample=True
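For reference, a minimal, self-contained sketch of what the new extract_json_only helper returns on a mixed prose-plus-JSON model output; the sample string below is invented for illustration and is not part of the Space.

import re

def extract_json_only(text):
    """Extract just the first JSON object from model text output."""
    pattern = r'\{(?:[^{}]|"[^"]*")*\}'
    matches = re.findall(pattern, text, re.DOTALL)
    return matches[0] if matches else ""

# Hypothetical raw model output mixing prose with a flat JSON object.
sample = (
    "Based on the incident report, my assessment is:\n"
    '{"cause": "unguarded conveyor belt", "degree_of_injury": "High"}\n'
    "Let me know if more detail is needed."
)

print(extract_json_only(sample))
# {"cause": "unguarded conveyor belt", "degree_of_injury": "High"}

Note that the pattern only matches flat (non-nested) JSON objects; if a model emits nested JSON, re.findall returns the innermost object rather than the full structure, which may be worth guarding against.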
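The diff also assumes flan_tokenizer and flan_model already exist at module level. A minimal sketch of that setup with Hugging Face transformers is shown below; the checkpoint name is an assumption, since the diff does not show which FLAN-T5 size the Space loads.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Assumed setup, not part of the diff: any FLAN-T5 seq2seq checkpoint should work.
flan_model_name = "google/flan-t5-base"  # hypothetical; the Space may pin a different size
flan_tokenizer = AutoTokenizer.from_pretrained(flan_model_name)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(flan_model_name).to("cpu")
flan_model.eval()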