TiberiuCristianLeon committed on
Commit
fb5c2d4
·
verified ·
1 Parent(s): ad83d27

Update src/translate/Translate.py

Browse files
Files changed (1) hide show
  1. src/translate/Translate.py +11 -7
src/translate/Translate.py CHANGED
@@ -60,17 +60,20 @@ def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
60
  model = 'Gargaz/gemma-2b-romanian-better'
61
  # limit max_new_tokens to 150% of the requestValue
62
  max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)
63
- pipe = pipeline(
 
64
  "text-generation",
65
  model=model,
66
  device=-1,
67
  max_new_tokens=max_new_tokens, # Keep short to reduce verbosity
68
  do_sample=False # Use greedy decoding for determinism
69
- )
70
- output = pipe(messages, num_return_sequences=1, return_full_text=False)
71
- generated_text = output[0]["generated_text"]
72
- result = generated_text.split('\n', 1)[0].strip()
73
- return result, model
 
 
74
 
75
  def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
76
  # Load model directly
@@ -99,7 +102,8 @@ def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-bette
99
 
100
  outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
101
  response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
102
- return response.strip()
 
103
  except Exception as error:
104
  return error
105
 
 
60
  model = 'Gargaz/gemma-2b-romanian-better'
61
  # limit max_new_tokens to 150% of the requestValue
62
  max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)
63
+ try:
64
+ pipe = pipeline(
65
  "text-generation",
66
  model=model,
67
  device=-1,
68
  max_new_tokens=max_new_tokens, # Keep short to reduce verbosity
69
  do_sample=False # Use greedy decoding for determinism
70
+ )
71
+ output = pipe(messages, num_return_sequences=1, return_full_text=False)
72
+ generated_text = output[0]["generated_text"]
73
+ result = generated_text.split('\n', 1)[0] if '\n' in generated_text else generated_text
74
+ return result.strip()
75
+ except Exception as error:
76
+ return error
77
 
78
  def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
79
  # Load model directly
 
102
 
103
  outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
104
  response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
105
+ result = response.split('\n', 1)[0] if '\n' in response else response
106
+ return result.strip()
107
  except Exception as error:
108
  return error
109