FrAnKu34t23 commited on
Commit
b06770b
·
verified ·
1 Parent(s): 67bff56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -54,19 +54,29 @@ def classify_injury_zero_shot(description):
54
 
55
  # === GENERATION FROM EACH MODEL ===
56
# === GENERATION FROM EACH MODEL ===
def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
    """Run one CPU-side sampled generation for *prompt* and return the decoded text.

    The prompt is tokenized (truncated to 512 tokens), fed through
    ``model.generate`` with nucleus/top-k sampling, and the full output
    sequence is decoded and stripped of surrounding whitespace.
    """
    # Tokenize and pin the tensors to CPU — this app runs without a GPU.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    prompt_len = encoded["input_ids"].shape[1]

    # Sampling configuration: total length budget = prompt + max_length new tokens.
    sampling_cfg = dict(
        max_length=prompt_len + max_length,
        temperature=temperature,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
    )

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling_cfg)

    text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return text.strip()
 
 
 
 
 
 
 
 
 
 
70
 
71
  # === ANALYSIS WITH FLAN-T5 ===
72
  def analyze_with_cpu_model(raw_outputs, zero_shot_injury):
 
54
 
55
  # === GENERATION FROM EACH MODEL ===
56
# === GENERATION FROM EACH MODEL ===
def generate_single_model_output(model, tokenizer, prompt, max_length=300, temperature=0.7):
    """Generate a completion for *prompt* from one causal LM on CPU.

    Parameters
    ----------
    model : a model object exposing ``generate`` (Hugging Face style)
    tokenizer : matching tokenizer — callable, with ``decode`` and
        ``eos_token_id``
    prompt : str — input text; truncated to 512 tokens before generation
    max_length : int — maximum number of NEW tokens to generate
    temperature : float — sampling temperature

    Returns
    -------
    str — the decoded output sequence (stripped), or a bracketed
    diagnostic string when the model yields nothing or raises. Errors are
    returned rather than raised so one failing model cannot take down the
    multi-model comparison UI.
    """
    try:
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        ).to("cpu")

        # Inference only — no gradient tracking.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                # max_new_tokens bounds only the generated continuation;
                # equivalent to the old max_length=input_len + max_length
                # but unambiguous and the HF-recommended form.
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
            )

        if output is not None and len(output) > 0:
            return tokenizer.decode(output[0], skip_special_tokens=True).strip()
        return "[No output was generated by the model.]"

    except Exception as e:  # boundary: report the failure in the UI instead of crashing
        return f"[Error generating output: {str(e)}]"
80
 
81
  # === ANALYSIS WITH FLAN-T5 ===
82
  def analyze_with_cpu_model(raw_outputs, zero_shot_injury):