Loomis Green commited on
Commit
c7e2760
·
1 Parent(s): cedf2c2

Downgrade transformers to 4.38.2 and add get_logits_warper monkeypatch

Browse files
Files changed (2) hide show
  1. app.py +25 -0
  2. requirements.txt +1 -1
app.py CHANGED
@@ -50,6 +50,31 @@ def _safe_isin(elements, test_elements, *args, **kwargs):
50
  torch.isin = _safe_isin
51
  print("Monkeypatch applied: torch.isin wrapper for int compatibility (robust)")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
54
  from fastapi.responses import FileResponse, StreamingResponse, RedirectResponse
55
  from fastapi.middleware.cors import CORSMiddleware
 
50
  torch.isin = _safe_isin
51
  print("Monkeypatch applied: torch.isin wrapper for int compatibility (robust)")
52
 
53
+ # Monkeypatch transformers.generation.GenerationMixin._get_logits_warper
54
+ # Fixes: TypeError: GenerationMixin._get_logits_warper() missing 1 required positional argument: 'device'
55
+ # This is needed if transformers version requires 'device' but Coqui TTS calls it without.
56
+ try:
57
+ from transformers.generation import GenerationMixin
58
+ _old_get_logits_warper = GenerationMixin._get_logits_warper
59
+
60
+ def _safe_get_logits_warper(self, generation_config, device=None, **kwargs):
61
+ # Coqui calls it as: self._get_logits_warper(generation_config)
62
+ if device is None:
63
+ # Try to infer device from the model (self)
64
+ device = getattr(self, "device", "cpu")
65
+
66
+ # Try calling with device first (fix for newer transformers)
67
+ try:
68
+ return _old_get_logits_warper(self, generation_config, device=device, **kwargs)
69
+ except TypeError:
70
+ # Fallback for older transformers that might not accept 'device'
71
+ return _old_get_logits_warper(self, generation_config, **kwargs)
72
+
73
+ GenerationMixin._get_logits_warper = _safe_get_logits_warper
74
+ print("Monkeypatch applied: GenerationMixin._get_logits_warper")
75
+ except Exception as e:
76
+ print(f"Could not monkeypatch GenerationMixin: {e}")
77
+
78
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form
79
  from fastapi.responses import FileResponse, StreamingResponse, RedirectResponse
80
  from fastapi.middleware.cors import CORSMiddleware
requirements.txt CHANGED
@@ -6,5 +6,5 @@ torch==2.4.0
6
  torchaudio==2.4.0
7
  numpy
8
  scipy
9
- transformers==4.42.0
10
  huggingface_hub
 
6
  torchaudio==2.4.0
7
  numpy
8
  scipy
9
+ transformers==4.38.2
10
  huggingface_hub